diff --git a/queries/reflect.go b/queries/reflect.go index 153d567c..855c5f90 100644 --- a/queries/reflect.go +++ b/queries/reflect.go @@ -115,12 +115,12 @@ func (q *Query) BindGP(ctx context.Context, obj any) { // For custom objects that want to use eager loading, please see the // loadRelationships function. func Bind(rows *sql.Rows, obj any) error { - structType, sliceType, singular, err := bindChecks(obj) + structType, _, singular, err := bindChecks(obj) if err != nil { return err } - return bind(rows, obj, structType, sliceType, singular) + return bind(rows, obj, structType, singular) } // Bind executes the query and inserts the @@ -133,7 +133,7 @@ func Bind(rows *sql.Rows, obj any) error { // // Also see documentation for Bind() func (q *Query) Bind(ctx context.Context, exec boil.Executor, obj any) error { - structType, sliceType, bkind, err := bindChecks(obj) + structType, _, bkind, err := bindChecks(obj) if err != nil { return err } @@ -147,7 +147,7 @@ func (q *Query) Bind(ctx context.Context, exec boil.Executor, obj any) error { if err != nil { return errors.Wrap(err, "bind failed to execute query") } - if err = bind(rows, obj, structType, sliceType, bkind); err != nil { + if err = bind(rows, obj, structType, bkind); err != nil { if innerErr := rows.Close(); innerErr != nil { return errors.Wrapf(err, "error on rows.Close after bind error: %+v", innerErr) } @@ -223,7 +223,7 @@ func bindChecks(obj any) (structType reflect.Type, sliceType reflect.Type, bkind } } -func bind(rows *sql.Rows, obj any, structType, sliceType reflect.Type, bkind bindKind) error { +func bind(rows *sql.Rows, obj any, structType reflect.Type, bkind bindKind) error { cols, err := rows.Columns() if err != nil { return errors.Wrap(err, "bind failed to get column names") @@ -240,11 +240,6 @@ func bind(rows *sql.Rows, obj any, structType, sliceType reflect.Type, bkind bin return err } - var oneStruct reflect.Value - if bkind == kindSliceStruct { - oneStruct = reflect.Indirect(reflect.New(structType)) - } - foundOne := false Rows: for rows.Next() { @@ -256,14 +251,12 @@ Rows: case kindStruct: pointers = PtrsFromMapping(reflect.Indirect(reflect.ValueOf(obj)), mapping) case kindSliceStruct: - pointers = PtrsFromMapping(oneStruct, mapping) + newStruct = reflect.Indirect(reflect.New(structType)) + pointers = PtrsFromMapping(newStruct, mapping) case kindPtrSliceStruct: newStruct = makeStructPtr(structType) pointers = PtrsFromMapping(reflect.Indirect(newStruct), mapping) } - if err != nil { - return err - } if err := rows.Scan(pointers...); err != nil { return errors.Wrap(err, "failed to bind pointers to obj") @@ -272,9 +265,7 @@ Rows: switch bkind { case kindStruct: break Rows - case kindSliceStruct: - ptrSlice.Set(reflect.Append(ptrSlice, oneStruct)) - case kindPtrSliceStruct: + case kindSliceStruct, kindPtrSliceStruct: ptrSlice.Set(reflect.Append(ptrSlice, newStruct)) } } @@ -871,4 +862,4 @@ func unTitleCase(n string) string { ret := buf.String() strmangle.PutBuffer(buf) return ret -} \ No newline at end of file +} diff --git a/queries/reflect_test.go b/queries/reflect_test.go index 0c82df28..a58bc535 100644 --- a/queries/reflect_test.go +++ b/queries/reflect_test.go @@ -14,6 +14,7 @@ import ( "github.com/aarondl/null/v8" "github.com/aarondl/sqlboiler/v4/drivers" + "github.com/aarondl/sqlboiler/v4/types" "github.com/DATA-DOG/go-sqlmock" ) @@ -181,6 +182,57 @@ func TestBindPtrSlice(t *testing.T) { } } +func TestBindJsonSlice(t *testing.T) { + t.Parallel() + + type siteInfoItem struct { + ID int `boil:"id" json:"id" toml:"id" yaml:"id"` + Fields types.JSON `boil:"test" json:"test" toml:"test" yaml:"test"` + } + + query := &Query{ + from: []string{"fun"}, + dialect: &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true}, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Error(err) + } + + ret := sqlmock.NewRows([]string{"id", "test"}) + ret.AddRow(driver.Value(int64(35)), driver.Value(`{"foo": "bar"}`)) + ret.AddRow(driver.Value(int64(12)), driver.Value("{}")) + mock.ExpectQuery(`SELECT \* FROM "fun";`).WillReturnRows(ret) + + var testResults []siteInfoItem + err = query.Bind(nil, db, &testResults) + if err != nil { + t.Error(err) + } + + if len(testResults) != 2 { + t.Fatal("wrong number of results:", len(testResults)) + } + if id := testResults[0].ID; id != 35 { + t.Error("wrong ID:", id) + } + if name := testResults[0].Fields; name.String() != `{"foo": "bar"}` { + t.Error("wrong name:", name) + } + + if id := testResults[1].ID; id != 12 { + t.Error("wrong ID:", id) + } + if name := testResults[1].Fields; name.String() != "{}" { + t.Error("wrong name:", name) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } +} + func testMakeMapping(byt ...byte) uint64 { var x uint64 for i, b := range byt {