Skip to content

Commit 25251b1

Browse files
committed
Prevent panic in compileMetadata() when final func is not defined for an aggregate
Previously, when there was a user defined aggregate without FINAL_FUNC defined the compileMetadata() function would panic due to nil pointer dereference. Gocql should properly handle this case since FINAL_FUNC is optional for user defined aggregates. This patch fixes the described problem by adding nil-checking before dereferencing. Patch by Bohdan Siryk; Reviewed by <> for CASSGO-105
1 parent f1e31a5 commit 25251b1

File tree

3 files changed

+169
-2
lines changed

3 files changed

+169
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- Return correct values from RowData (CASSGO-95)
1919
- Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98)
2020
- Use protocol downgrading approach during protocol negotiation (CASSGO-97)
21+
- Prevent panic iin compileMetadata() when final func is not defined for an aggregate (CASSGO-105)
2122

2223
## [2.0.0]
2324

metadata.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,14 @@ func compileMetadata(
359359
}
360360
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
361361
for i, _ := range aggregates {
362-
aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc]
363-
aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc]
362+
finalFunc := keyspace.Functions[aggregates[i].finalFunc]
363+
if finalFunc != nil {
364+
aggregates[i].FinalFunc = *finalFunc
365+
}
366+
stateFunc := keyspace.Functions[aggregates[i].stateFunc]
367+
if stateFunc != nil {
368+
aggregates[i].StateFunc = *stateFunc
369+
}
364370
keyspace.Aggregates[aggregates[i].Name] = &aggregates[i]
365371
}
366372
keyspace.UserTypes = make(map[string]*UserTypeMetadata, len(uTypes))

metadata_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ package gocql
3434
import (
3535
"strconv"
3636
"testing"
37+
38+
"github.com/stretchr/testify/require"
3739
)
3840

3941
// Tests V1 and V2 metadata "compilation" from example data which might be returned
@@ -596,3 +598,161 @@ func assertParseNonCompositeTypes(
596598
}
597599
}
598600
}
601+
602+
func TestCompileMetadataWithFunctions(t *testing.T) {
603+
session := &Session{
604+
cfg: ClusterConfig{
605+
ProtoVersion: protoVersion5,
606+
Logger: NewLogger(LogLevelInfo),
607+
},
608+
types: GlobalTypes.Copy(),
609+
}
610+
611+
keyspace := &KeyspaceMetadata{
612+
Name: "test_keyspace",
613+
}
614+
615+
functions := []FunctionMetadata{
616+
{
617+
Keyspace: "test_keyspace",
618+
Name: "test_func",
619+
ArgumentTypes: []TypeInfo{intTypeInfo{}},
620+
ArgumentNames: []string{"arg1"},
621+
Body: "return arg1 + 1;",
622+
CalledOnNullInput: false,
623+
},
624+
{
625+
Keyspace: "test_keyspace",
626+
Name: "test_func_no_args",
627+
ArgumentTypes: []TypeInfo{},
628+
ArgumentNames: []string{},
629+
Body: "return 1;",
630+
CalledOnNullInput: false,
631+
},
632+
{
633+
Keyspace: "test_keyspace",
634+
Name: "test_func_null_input",
635+
ArgumentTypes: []TypeInfo{intTypeInfo{}},
636+
ArgumentNames: []string{"arg1"},
637+
Body: "if (arg1 == null) return 0; else return arg1;",
638+
CalledOnNullInput: true,
639+
},
640+
}
641+
642+
compileMetadata(session, keyspace, nil, nil, functions, nil, nil, nil)
643+
644+
require.Len(t, keyspace.Functions, 3, "Expected to have 3 functions")
645+
require.Contains(t, keyspace.Functions, "test_func")
646+
require.Contains(t, keyspace.Functions, "test_func_no_args")
647+
require.Contains(t, keyspace.Functions, "test_func_null_input")
648+
649+
testFunc := keyspace.Functions["test_func"]
650+
require.Equal(t, "test_func", testFunc.Name)
651+
require.Equal(t, 1, len(testFunc.ArgumentTypes))
652+
require.Equal(t, TypeInt, testFunc.ArgumentTypes[0].Type())
653+
require.Len(t, len(testFunc.ArgumentNames), 1)
654+
require.Equal(t, "arg1", testFunc.ArgumentNames[0])
655+
require.Equal(t, "return arg1 + 1;", testFunc.Body)
656+
require.False(t, testFunc.CalledOnNullInput)
657+
658+
testFuncNoArgs := keyspace.Functions["test_func_no_args"]
659+
require.Equal(t, "test_func_no_args", testFuncNoArgs.Name)
660+
require.Empty(t, testFuncNoArgs.ArgumentTypes)
661+
require.Empty(t, testFuncNoArgs.ArgumentNames)
662+
require.Equal(t, "return 1;", testFuncNoArgs.Body)
663+
require.False(t, testFuncNoArgs.CalledOnNullInput)
664+
665+
testFuncNullInput := keyspace.Functions["test_func_null_input"]
666+
require.Equal(t, "test_func_null_input", testFuncNullInput.Name)
667+
require.Len(t, testFuncNullInput.ArgumentTypes, 1)
668+
require.Equal(t, TypeInt, testFuncNullInput.ArgumentTypes[0].Type())
669+
require.Len(t, testFuncNullInput.ArgumentNames, 1)
670+
require.Equal(t, "arg1", testFuncNullInput.ArgumentNames[0])
671+
require.Equal(t, "if (arg1 == null) return 0; else return arg1;", testFuncNullInput.Body)
672+
require.True(t, testFuncNullInput.CalledOnNullInput)
673+
}
674+
675+
func TestCompileMetadataWithAggregates(t *testing.T) {
676+
session := &Session{
677+
cfg: ClusterConfig{
678+
ProtoVersion: protoVersion5,
679+
Logger: NewLogger(LogLevelInfo),
680+
},
681+
types: GlobalTypes.Copy(),
682+
}
683+
684+
keyspace := &KeyspaceMetadata{
685+
Name: "test_keyspace",
686+
}
687+
688+
functions := []FunctionMetadata{
689+
{
690+
Keyspace: "test_keyspace",
691+
Name: "test_state_func",
692+
ArgumentTypes: []TypeInfo{intTypeInfo{}},
693+
ArgumentNames: []string{"arg1"},
694+
Body: "return arg1 + 1;",
695+
CalledOnNullInput: false,
696+
},
697+
{
698+
Keyspace: "test_keyspace",
699+
Name: "test_final_func",
700+
ArgumentTypes: []TypeInfo{floatTypeInfo{}},
701+
ArgumentNames: []string{"arg1"},
702+
Body: "return arg1 + 1;",
703+
CalledOnNullInput: false,
704+
},
705+
}
706+
707+
aggregates := []AggregateMetadata{
708+
{
709+
Keyspace: "test_keyspace",
710+
Name: "test_agg",
711+
ArgumentTypes: []TypeInfo{
712+
intTypeInfo{},
713+
},
714+
InitCond: "0",
715+
StateFunc: functions[0],
716+
FinalFunc: functions[1],
717+
ReturnType: intTypeInfo{},
718+
StateType: intTypeInfo{},
719+
stateFunc: "test_state_func",
720+
finalFunc: "test_final_func",
721+
},
722+
{
723+
Keyspace: "test_keyspace",
724+
Name: "test_agg_no_final_func",
725+
ArgumentTypes: []TypeInfo{
726+
doubleTypeInfo{},
727+
},
728+
InitCond: "0",
729+
StateFunc: functions[0],
730+
ReturnType: doubleTypeInfo{},
731+
StateType: doubleTypeInfo{},
732+
stateFunc: "test_state_func",
733+
finalFunc: "",
734+
},
735+
}
736+
737+
compileMetadata(session, keyspace, nil, nil, functions, aggregates, nil, nil)
738+
739+
require.Len(t, keyspace.Aggregates, 2, "Expected to have 2 aggregates")
740+
require.Contains(t, keyspace.Aggregates, "test_agg")
741+
require.Contains(t, keyspace.Aggregates, "test_agg_no_final_func")
742+
743+
testAgg := keyspace.Aggregates["test_agg"]
744+
require.Equal(t, "test_agg", testAgg.Name)
745+
require.Len(t, testAgg.ArgumentTypes, 1)
746+
require.Equal(t, TypeInt, testAgg.ArgumentTypes[0].Type())
747+
require.Equal(t, "0", testAgg.InitCond)
748+
require.Equal(t, TypeInt, testAgg.ReturnType.Type())
749+
require.Equal(t, TypeInt, testAgg.StateType.Type())
750+
751+
testAggNoFinalFunc := keyspace.Aggregates["test_agg_no_final_func"]
752+
require.Equal(t, "test_agg_no_final_func", testAggNoFinalFunc.Name)
753+
require.Len(t, testAggNoFinalFunc.ArgumentTypes, 1)
754+
require.Equal(t, TypeDouble, testAggNoFinalFunc.ArgumentTypes[0].Type())
755+
require.Equal(t, "0", testAggNoFinalFunc.InitCond)
756+
require.Equal(t, TypeDouble, testAggNoFinalFunc.ReturnType.Type())
757+
require.Equal(t, TypeDouble, testAggNoFinalFunc.StateType.Type())
758+
}

0 commit comments

Comments
 (0)