@@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
3535 return v
3636}
3737
38- func (f * fixturesLoader ) mssqlTableHasIdentityColumn (db * sql.DB , tableName string ) (bool , error ) {
39- row := db .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
38+ func (f * fixturesLoader ) mssqlTableHasIdentityColumn (q * sql.Tx , tableName string ) (bool , error ) {
39+ row := q .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
4040 var count int
4141 if err := row .Scan (& count ); err != nil {
4242 return false , err
4343 }
4444 return count > 0 , nil
4545}
4646
47- func (f * fixturesLoader ) loadFixtures (file string ) error {
47+ func (f * fixturesLoader ) loadFixtures (tx * sql. Tx , file string ) error {
4848 data , err := os .ReadFile (file )
4949 if err != nil {
5050 return fmt .Errorf ("failed to read file %q: %w" , file , err )
@@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {
5757
5858 tableName , _ , _ := strings .Cut (filepath .Base (file ), "." )
5959 tableNameQuoted := f .quoteObject (tableName )
60- _ , err = f . engine . Table ( tableName ). Where ( "1=1" ). Delete ( ) // sqlite3 doesn't support truncate
60+ _ , err = tx . Exec ( fmt . Sprintf ( "DELETE FROM %s" , tableNameQuoted ) ) // sqlite3 doesn't support truncate
6161 if err != nil {
6262 return err
6363 }
6464
65- goDB := f .engine .DB ().DB
66- tx , err := goDB .Begin ()
67- if err != nil {
68- return err
69- }
70- defer func () {
71- if tx != nil {
72- _ = tx .Rollback ()
73- }
74- }()
75-
7665 switch f .engine .Dialect ().URI ().DBType {
7766 case schemas .MSSQL :
78- hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (goDB , tableName )
67+ hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (tx , tableName )
7968 if err != nil {
8069 return err
8170 }
@@ -112,16 +101,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
112101 sqlBuf = sqlBuf [:0 ]
113102 sqlArguments = sqlArguments [:0 ]
114103 }
115- err = tx .Commit ()
116- tx = nil
117- return err
104+ return nil
118105}
119106
120107func (f * fixturesLoader ) Load () error {
108+ goDB := f .engine .DB ().DB
109+
121110 switch f .engine .Dialect ().URI ().DBType {
122111 case schemas .SQLITE :
123112 f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
124113 f .paramPlaceholder = func (idx int ) string { return "?" }
114+ if _ , err := goDB .Exec ("PRAGMA defer_foreign_keys = ON" ); err != nil {
115+ return err
116+ }
117+ defer func () { _ , _ = goDB .Exec ("PRAGMA defer_foreign_keys = OFF" ) }()
125118 case schemas .POSTGRES :
126119 f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
127120 f .paramPlaceholder = func (idx int ) string { return fmt .Sprintf (`$%d` , idx ) }
@@ -141,13 +134,20 @@ func (f *fixturesLoader) Load() error {
141134 f .opts .Files = append (f .opts .Files , e .Name ())
142135 }
143136 }
137+
138+ tx , err := goDB .Begin ()
139+ if err != nil {
140+ return err
141+ }
142+ defer func () { _ = tx .Rollback () }()
143+
144144 for _ , file := range f .opts .Files {
145145 if ! filepath .IsAbs (file ) {
146146 file = filepath .Join (f .opts .Dir , file )
147147 }
148- if err := f .loadFixtures (file ); err != nil {
148+ if err := f .loadFixtures (tx , file ); err != nil {
149149 return fmt .Errorf ("failed to load fixtures from %s: %w" , file , err )
150150 }
151151 }
152- return nil
152+ return tx . Commit ()
153153}
0 commit comments