44package unittest
55
66import (
7+ "database/sql"
78 "encoding/hex"
89 "fmt"
910 "os"
1011 "path/filepath"
1112 "strings"
1213
13- "github.com/yuin/goldmark/util"
14+ "code.gitea.io/gitea/modules/util"
15+
1416 "gopkg.in/yaml.v3"
1517 "xorm.io/xorm"
1618 "xorm.io/xorm/schemas"
1719)
1820
1921type fixturesLoader struct {
20- engine * xorm.Engine
21- opts FixturesOptions
22- quoteObject func (string ) string
22+ engine * xorm.Engine
23+ opts FixturesOptions
24+ quoteObject func (string ) string
25+ paramPlaceholder func (idx int ) string
2326}
2427
2528func (f * fixturesLoader ) prepareFieldValue (v any ) any {
2629 if s , ok := v .(string ); ok {
2730 if strings .HasPrefix (s , "::HEX::" ) {
28- b , _ := hex .DecodeString (s [8 :])
31+ b , _ := hex .DecodeString (s [7 :])
2932 return b
3033 }
3134 }
3235 return v
3336}
37+
38+ func (f * fixturesLoader ) mssqlTableHasIdentityColumn (db * sql.DB , tableName string ) (bool , error ) {
39+ row := db .QueryRow (`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)` , tableName )
40+ var count int
41+ if err := row .Scan (& count ); err != nil {
42+ return false , err
43+ }
44+ return count > 0 , nil
45+ }
46+
3447func (f * fixturesLoader ) loadFixtures (file string ) error {
3548 data , err := os .ReadFile (file )
3649 if err != nil {
@@ -49,16 +62,34 @@ func (f *fixturesLoader) loadFixtures(file string) error {
4962 return err
5063 }
5164
65+ goDB := f .engine .DB ().DB
66+ tx , err := goDB .Begin ()
67+ if err != nil {
68+ return err
69+ }
70+ defer func () {
71+ if tx != nil {
72+ _ = tx .Rollback ()
73+ }
74+ }()
75+
5276 switch f .engine .Dialect ().URI ().DBType {
5377 case schemas .MSSQL :
54- _ , _ = f .engine .Exec ("SET IDENTITY_INSERT [%s] ON" , tableName )
55- defer func () {
56- _ , _ = f .engine .Exec ("SET IDENTITY_INSERT [%s] OFF" , tableName )
57- }()
78+ hasIdentityColumn , err := f .mssqlTableHasIdentityColumn (goDB , tableName )
79+ if err != nil {
80+ return err
81+ }
82+ if hasIdentityColumn {
83+ _ , err = tx .Exec (fmt .Sprintf ("SET IDENTITY_INSERT %s ON" , tableNameQuoted ))
84+ if err != nil {
85+ return err
86+ }
87+ }
5888 }
5989
6090 var sqlBuf []byte
6191 var sqlArguments []any
92+ paramIdx := 1
6293 for _ , item := range fixtureItems {
6394 sqlBuf = append (sqlBuf , fmt .Sprintf ("INSERT INTO %s (" , tableNameQuoted )... )
6495 for k , v := range item {
@@ -69,22 +100,37 @@ func (f *fixturesLoader) loadFixtures(file string) error {
69100 sqlBuf = sqlBuf [:len (sqlBuf )- 1 ]
70101 sqlBuf = append (sqlBuf , ") VALUES (" ... )
71102 for range item {
72- sqlBuf = append (sqlBuf , "?," ... )
103+ sqlBuf = append (sqlBuf , f .paramPlaceholder (paramIdx )... )
104+ sqlBuf = append (sqlBuf , ',' )
105+ paramIdx ++
73106 }
74107 sqlBuf [len (sqlBuf )- 1 ] = ')'
75- _ , err = f . engine . Exec (append ([] any { util .BytesToReadOnlyString (sqlBuf )} , sqlArguments ... ) ... )
108+ _ , err = tx . Exec (util .UnsafeBytesToString (sqlBuf ), sqlArguments ... )
76109 if err != nil {
77110 return err
78111 }
79112 sqlBuf = sqlBuf [:0 ]
80113 sqlArguments = sqlArguments [:0 ]
81114 }
82- return nil
115+ err = tx .Commit ()
116+ tx = nil
117+ return err
83118}
84119
85120func (f * fixturesLoader ) Load () error {
86- f .quoteObject = func (s string ) string {
87- return fmt .Sprintf ("`%s`" , s )
121+ switch f .engine .Dialect ().URI ().DBType {
122+ case schemas .SQLITE :
123+ f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
124+ f .paramPlaceholder = func (idx int ) string { return "?" }
125+ case schemas .POSTGRES :
126+ f .quoteObject = func (s string ) string { return fmt .Sprintf (`"%s"` , s ) }
127+ f .paramPlaceholder = func (idx int ) string { return fmt .Sprintf (`$%d"` , idx ) }
128+ case schemas .MYSQL :
129+ f .quoteObject = func (s string ) string { return fmt .Sprintf ("`%s`" , s ) }
130+ f .paramPlaceholder = func (idx int ) string { return "?" }
131+ case schemas .MSSQL :
132+ f .quoteObject = func (s string ) string { return fmt .Sprintf ("[%s]" , s ) }
133+ f .paramPlaceholder = func (idx int ) string { return "?" }
88134 }
89135 if len (f .opts .Files ) == 0 {
90136 entries , err := os .ReadDir (f .opts .Dir )
0 commit comments