Skip to content

Commit 7f83407

Browse files
Optimize evalPopulationStdDevDecimal with Welford's algorithm
Co-authored-by: suyashkumar <6299853+suyashkumar@users.noreply.github.com>
1 parent 8a5a5fa commit 7f83407

File tree

3 files changed

+147
-28
lines changed

3 files changed

+147
-28
lines changed

interpreter/operator_aggregate.go

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,32 +1038,10 @@ func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, oper
10381038
return result.Value{}, err
10391039
}
10401040

1041-
countValue, err := i.evalCount(m, operand)
1042-
if err != nil {
1043-
return result.Value{}, err
1044-
}
1045-
if result.IsNull(countValue) {
1046-
return result.New(nil)
1047-
}
1048-
count, err := result.ToInt32(countValue)
1049-
if err != nil {
1050-
return result.Value{}, err
1051-
}
1052-
if count == 0 {
1053-
return result.New(nil)
1054-
}
1055-
meanValue, err := i.evalAvg(m, operand)
1056-
if err != nil {
1057-
return result.Value{}, err
1058-
}
1059-
if result.IsNull(meanValue) {
1060-
return result.New(nil)
1061-
}
1062-
mean, err := result.ToFloat64(meanValue)
1063-
if err != nil {
1064-
return result.Value{}, err
1065-
}
1066-
var sum float64
1041+
var count float64
1042+
var mean float64
1043+
var m2 float64
1044+
10671045
for _, elem := range l {
10681046
if result.IsNull(elem) {
10691047
continue
@@ -1072,10 +1050,20 @@ func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, oper
10721050
if err != nil {
10731051
return result.Value{}, err
10741052
}
1075-
sum += (v - mean) * (v - mean)
1053+
1054+
count++
1055+
delta := v - mean
1056+
mean += delta / count
1057+
delta2 := v - mean
1058+
m2 += delta * delta2
1059+
}
1060+
1061+
if count == 0 {
1062+
return result.New(nil)
10761063
}
1064+
10771065
// Round to 8 decimal places to match CQL expected precision
1078-
stdDev := math.Sqrt(sum / float64(count))
1066+
stdDev := math.Sqrt(m2 / count)
10791067
roundedStdDev := math.Round(stdDev*100000000) / 100000000
10801068
return result.New(roundedStdDev)
10811069
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package enginetests
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
"testing"
8+
9+
"github.com/google/cql/interpreter"
10+
"github.com/google/cql/parser"
11+
"github.com/google/cql/result"
12+
)
13+
14+
func BenchmarkPopulationStdDev(b *testing.B) {
15+
// Create a large list of decimals
16+
count := 10000
17+
values := make([]string, count)
18+
for i := 0; i < count; i++ {
19+
values[i] = fmt.Sprintf("%d.0", i%100)
20+
}
21+
listCQL := "{" + strings.Join(values, ", ") + "}"
22+
cql := "PopulationStdDev(" + listCQL + ")"
23+
24+
p := newFHIRParser(b)
25+
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(b, cql), parser.Config{})
26+
if err != nil {
27+
b.Fatalf("Parse Libraries returned unexpected error: %v", err)
28+
}
29+
30+
config := interpreter.Config{
31+
DataModels: p.DataModel(),
32+
Retriever: BuildRetriever(b),
33+
Terminology: buildTerminologyProvider(b),
34+
EvaluationTimestamp: defaultEvalTimestamp,
35+
ReturnPrivateDefs: true,
36+
}
37+
38+
b.ResetTimer()
39+
b.Run("LargeList", func(b *testing.B) {
40+
var force result.Libraries
41+
for n := 0; n < b.N; n++ {
42+
force, err = interpreter.Eval(context.Background(), parsedLibs, config)
43+
if err != nil {
44+
b.Fatalf("Eval returned unexpected error: %v", err)
45+
}
46+
}
47+
forceBenchResult = force
48+
})
49+
}

tests/enginetests/operator_aggregate_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,3 +1387,85 @@ func TestMode(t *testing.T) {
13871387
})
13881388
}
13891389
}
1390+
1391+
func TestPopulationStdDev(t *testing.T) {
1392+
tests := []struct {
1393+
name string
1394+
cql string
1395+
wantModel model.IExpression
1396+
wantResult result.Value
1397+
}{
1398+
{
1399+
name: "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})",
1400+
cql: "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})",
1401+
wantModel: &model.PopulationStdDev{
1402+
UnaryExpression: &model.UnaryExpression{
1403+
Operand: model.NewList([]string{"1.0", "2.0", "3.0", "4.0", "5.0"}, types.Decimal),
1404+
Expression: model.ResultType(types.Decimal),
1405+
},
1406+
},
1407+
wantResult: newOrFatal(t, 1.41421356),
1408+
},
1409+
{
1410+
name: "PopulationStdDev with unordered decimal list",
1411+
cql: "PopulationStdDev({5.0, 2.0, 1.0, 4.0, 3.0})",
1412+
wantResult: newOrFatal(t, 1.41421356),
1413+
},
1414+
{
1415+
name: "PopulationStdDev with all identical values",
1416+
cql: "PopulationStdDev({3.0, 3.0, 3.0, 3.0})",
1417+
wantResult: newOrFatal(t, 0.0),
1418+
},
1419+
{
1420+
name: "PopulationStdDev with null input",
1421+
cql: "PopulationStdDev(null as List<Decimal>)",
1422+
wantResult: newOrFatal(t, nil),
1423+
},
1424+
{
1425+
name: "PopulationStdDev with empty list",
1426+
cql: "PopulationStdDev({} as List<Decimal>)",
1427+
wantResult: newOrFatal(t, nil),
1428+
},
1429+
{
1430+
name: "PopulationStdDev with single value",
1431+
cql: "PopulationStdDev({5.0})",
1432+
wantResult: newOrFatal(t, 0.0),
1433+
},
1434+
{
1435+
name: "PopulationStdDev with null values in list",
1436+
cql: "PopulationStdDev({1.0, null, 3.0, null, 5.0})",
1437+
wantResult: newOrFatal(t, 1.63299316),
1438+
},
1439+
{
1440+
name: "PopulationStdDev with all null values",
1441+
cql: "PopulationStdDev({null, null, null} as List<Decimal>)",
1442+
wantResult: newOrFatal(t, nil),
1443+
},
1444+
{
1445+
name: "PopulationStdDev with quantities",
1446+
cql: "PopulationStdDev({1.0 'g', 2.0 'g', 3.0 'g', 4.0 'g', 5.0 'g'})",
1447+
wantResult: newOrFatal(t, result.Quantity{Value: 1.4142135623730951, Unit: "g"}),
1448+
},
1449+
}
1450+
1451+
for _, tc := range tests {
1452+
t.Run(tc.name, func(t *testing.T) {
1453+
p := newFHIRParser(t)
1454+
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
1455+
if err != nil {
1456+
t.Fatalf("Parse returned unexpected error: %v", err)
1457+
}
1458+
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
1459+
t.Errorf("Parse diff (-want +got):\n%s", diff)
1460+
}
1461+
1462+
results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
1463+
if err != nil {
1464+
t.Fatalf("Eval returned unexpected error: %v", err)
1465+
}
1466+
if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" {
1467+
t.Errorf("Eval diff (-want +got)\n%v", diff)
1468+
}
1469+
})
1470+
}
1471+
}

0 commit comments

Comments
 (0)