diff --git a/database/utils.go b/database/utils.go index 43ed574d..77b1762c 100644 --- a/database/utils.go +++ b/database/utils.go @@ -10,6 +10,8 @@ import ( "github.com/icinga/icinga-go-library/types" "github.com/jmoiron/sqlx" "github.com/pkg/errors" + "slices" + "strings" ) // CantPerformQuery wraps the given error with the specified query that cannot be executed. @@ -81,6 +83,19 @@ func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int return resultID, nil } +// BuildInsertStmtWithout builds an insert stmt without the provided columns. +func BuildInsertStmtWithout(db *DB, into interface{}, withoutColumns ...string) string { + columns := slices.DeleteFunc( + db.BuildColumns(into), + func(column string) bool { return slices.Contains(withoutColumns, column) }) + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s)`, + TableName(into), strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ) +} + // unsafeSetSessionVariableIfExists sets the given MySQL/MariaDB system variable for the specified database session. // // NOTE: It is unsafe to use this function with untrusted/user supplied inputs and poses an SQL injection, diff --git a/types/int.go b/types/int.go index 448180f1..7189f2a9 100644 --- a/types/int.go +++ b/types/int.go @@ -14,6 +14,29 @@ type Int struct { sql.NullInt64 } +// TransformZeroIntToNull transforms a valid Int carrying a zero value to a SQL NULL. +func TransformZeroIntToNull(i *Int) { + if i.Valid && i.Int64 == 0 { + i.Valid = false + } +} + +// MakeInt constructs a new Int. +// +// Multiple transformer functions can be given, each transforming the generated Int, e.g., TransformZeroIntToNull. +func MakeInt(in int64, transformers ...func(*Int)) Int { + i := Int{sql.NullInt64{ + Int64: in, + Valid: true, + }} + + for _, transformer := range transformers { + transformer(&i) + } + + return i +} + // MarshalJSON implements the json.Marshaler interface. // Supports JSON null. func (i Int) MarshalJSON() ([]byte, error) { diff --git a/types/int_test.go b/types/int_test.go index d8a0506c..0e2bd3ba 100644 --- a/types/int_test.go +++ b/types/int_test.go @@ -6,6 +6,49 @@ import ( "testing" ) +func TestMakeInt(t *testing.T) { + subtests := []struct { + name string + input int64 + transformers []func(*Int) + output sql.NullInt64 + }{ + { + name: "zero", + input: 0, + output: sql.NullInt64{Int64: 0, Valid: true}, + }, + { + name: "positive", + input: 1, + output: sql.NullInt64{Int64: 1, Valid: true}, + }, + { + name: "negative", + input: -1, + output: sql.NullInt64{Int64: -1, Valid: true}, + }, + { + name: "zero-transform-zero-to-null", + input: 0, + transformers: []func(*Int){TransformZeroIntToNull}, + output: sql.NullInt64{Valid: false}, + }, + { + name: "positive-transform-zero-to-null", + input: 1, + transformers: []func(*Int){TransformZeroIntToNull}, + output: sql.NullInt64{Int64: 1, Valid: true}, + }, + } + + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + require.Equal(t, Int{NullInt64: st.output}, MakeInt(st.input, st.transformers...)) + }) + } +} + func TestInt_MarshalJSON(t *testing.T) { subtests := []struct { name string diff --git a/types/string.go b/types/string.go index c01bf964..8b4160a4 100644 --- a/types/string.go +++ b/types/string.go @@ -14,12 +14,27 @@ type String struct { sql.NullString } -// MakeString constructs a new non-NULL String from s. -func MakeString(s string) String { - return String{sql.NullString{ - String: s, +// TransformEmptyStringToNull transforms a valid String carrying an empty text to a SQL NULL. +func TransformEmptyStringToNull(s *String) { + if s.Valid && s.String == "" { + s.Valid = false + } +} + +// MakeString constructs a new String. +// +// Multiple transformer functions can be given, each transforming the generated String, e.g., TransformEmptyStringToNull. +func MakeString(in string, transformers ...func(*String)) String { + s := String{sql.NullString{ + String: in, Valid: true, }} + + for _, transformer := range transformers { + transformer(&s) + } + + return s } // MarshalJSON implements the json.Marshaler interface. diff --git a/types/string_test.go b/types/string_test.go index 1ebc9927..2ea6fd35 100644 --- a/types/string_test.go +++ b/types/string_test.go @@ -9,18 +9,48 @@ import ( func TestMakeString(t *testing.T) { subtests := []struct { - name string - io string + name string + input string + transformers []func(*String) + output sql.NullString }{ - {"empty", ""}, - {"nul", "\x00"}, - {"space", " "}, - {"multiple", "abc"}, + { + name: "empty", + input: "", + output: sql.NullString{String: "", Valid: true}, + }, + { + name: "nul", + input: "\x00", + output: sql.NullString{String: "\x00", Valid: true}, + }, + { + name: "space", + input: " ", + output: sql.NullString{String: " ", Valid: true}, + }, + { + name: "valid-text", + input: "abc", + output: sql.NullString{String: "abc", Valid: true}, + }, + { + name: "empty-transform-empty-to-null", + input: "", + transformers: []func(*String){TransformEmptyStringToNull}, + output: sql.NullString{Valid: false}, + }, + { + name: "valid-text-transform-empty-to-null", + input: "abc", + transformers: []func(*String){TransformEmptyStringToNull}, + output: sql.NullString{String: "abc", Valid: true}, + }, } for _, st := range subtests { t.Run(st.name, func(t *testing.T) { - require.Equal(t, String{NullString: sql.NullString{String: st.io, Valid: true}}, MakeString(st.io)) + require.Equal(t, String{NullString: st.output}, MakeString(st.input, st.transformers...)) }) } } diff --git a/utils/utils.go b/utils/utils.go index 1fbfd156..5dd0ece6 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "cmp" "context" "crypto/sha1" // #nosec G505 -- Blocklisted import crypto/sha1 "fmt" @@ -8,9 +9,11 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" "golang.org/x/exp/utf8string" + "iter" "net" "os" "path/filepath" + "slices" "strings" "time" ) @@ -163,3 +166,23 @@ func PrintErrorThenExit(err error, exitCode int) { fmt.Fprintln(os.Stderr, err) os.Exit(exitCode) } + +// IterateOrderedMap implements iter.Seq2 to iterate over a map in the key's order. +// +// This function returns a func yielding key-value-pairs from a given map in the order of their keys, if their type +// is cmp.Ordered. +func IterateOrderedMap[K cmp.Ordered, V any](m map[K]V) iter.Seq2[K, V] { + keys := make([]K, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + slices.Sort(keys) + + return func(yield func(K, V) bool) { + for _, key := range keys { + if !yield(key, m[key]) { + return + } + } + } +} diff --git a/utils/utils_test.go b/utils/utils_test.go index b0ea54b8..202dcbad 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,6 +1,7 @@ package utils import ( + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" ) @@ -52,3 +53,44 @@ func requireClosedEmpty(t *testing.T, ch <-chan int) { require.Fail(t, "receiving should not block") } } + +func TestIterateOrderedMap(t *testing.T) { + tests := []struct { + name string + in map[int]string + outKeys []int + }{ + {"empty", map[int]string{}, nil}, + {"single", map[int]string{1: "foo"}, []int{1}}, + {"few-numbers", map[int]string{1: "a", 2: "b", 3: "c"}, []int{1, 2, 3}}, + { + "1k-numbers", + func() map[int]string { + m := make(map[int]string) + for i := 0; i < 1000; i++ { + m[i] = "foo" + } + return m + }(), + func() []int { + keys := make([]int, 1000) + for i := 0; i < 1000; i++ { + keys[i] = i + } + return keys + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var outKeys []int + for k, v := range IterateOrderedMap(tt.in) { + assert.Equal(t, tt.in[k], v) + outKeys = append(outKeys, k) + } + + assert.Equal(t, tt.outKeys, outKeys) + }) + } +}