@@ -16,7 +16,7 @@ type Args struct {
1616 // The default flavor used by `Args#Compile`
1717 Flavor Flavor
1818
19- args []interface {}
19+ argValues []interface {}
2020 namedArgs map [string ]int
2121 sqlNamedArgs map [string ]int
2222 onlyNamed bool
@@ -47,7 +47,7 @@ func (args *Args) Add(arg interface{}) string {
4747}
4848
4949func (args * Args ) add (arg interface {}) int {
50- idx := len (args .args )
50+ idx := len (args .argValues )
5151
5252 switch a := arg .(type ) {
5353 case sql.NamedArg :
@@ -56,7 +56,7 @@ func (args *Args) add(arg interface{}) int {
5656 }
5757
5858 if p , ok := args .sqlNamedArgs [a .Name ]; ok {
59- arg = args .args [p ]
59+ arg = args .argValues [p ]
6060 break
6161 }
6262
@@ -67,7 +67,7 @@ func (args *Args) add(arg interface{}) int {
6767 }
6868
6969 if p , ok := args .namedArgs [a .name ]; ok {
70- arg = args .args [p ]
70+ arg = args .argValues [p ]
7171 break
7272 }
7373
@@ -77,7 +77,7 @@ func (args *Args) add(arg interface{}) int {
7777 return idx
7878 }
7979
80- args .args = append (args .args , arg )
80+ args .argValues = append (args .argValues , arg )
8181 return idx
8282}
8383
@@ -97,55 +97,58 @@ func (args *Args) Compile(format string, initialValue ...interface{}) (query str
9797//
9898// See doc for `Compile` to learn details.
9999func (args * Args ) CompileWithFlavor (format string , flavor Flavor , initialValue ... interface {}) (query string , values []interface {}) {
100- buf := newStringBuilder ()
101100 idx := strings .IndexRune (format , '$' )
102101 offset := 0
103- values = initialValue
102+ ctx := & argsCompileContext {
103+ stringBuilder : newStringBuilder (),
104+ Flavor : flavor ,
105+ Values : initialValue ,
106+ }
104107
105- if flavor == invalidFlavor {
106- flavor = DefaultFlavor
108+ if ctx . Flavor == invalidFlavor {
109+ ctx . Flavor = DefaultFlavor
107110 }
108111
109112 for idx >= 0 && len (format ) > 0 {
110113 if idx > 0 {
111- buf .WriteString (format [:idx ])
114+ ctx .WriteString (format [:idx ])
112115 }
113116
114117 format = format [idx + 1 :]
115118
116119 // Treat the $ at the end of format is a normal $ rune.
117120 if len (format ) == 0 {
118- buf .WriteRune ('$' )
121+ ctx .WriteRune ('$' )
119122 break
120123 }
121124
122125 if r := format [0 ]; r == '$' {
123- buf .WriteRune ('$' )
126+ ctx .WriteRune ('$' )
124127 format = format [1 :]
125128 } else if r == '{' {
126- format , values = args .compileNamed (buf , flavor , format , values )
129+ format = args .compileNamed (ctx , format )
127130 } else if ! args .onlyNamed && '0' <= r && r <= '9' {
128- format , values , offset = args .compileDigits (buf , flavor , format , values , offset )
131+ format , offset = args .compileDigits (ctx , format , offset )
129132 } else if ! args .onlyNamed && r == '?' {
130- format , values , offset = args .compileSuccessive (buf , flavor , format [1 :], values , offset )
133+ format , offset = args .compileSuccessive (ctx , format [1 :], offset )
131134 } else {
132135 // For unknown $ expression format, treat it as a normal $ rune.
133- buf .WriteRune ('$' )
136+ ctx .WriteRune ('$' )
134137 }
135138
136139 idx = strings .IndexRune (format , '$' )
137140 }
138141
139142 if len (format ) > 0 {
140- buf .WriteString (format )
143+ ctx .WriteString (format )
141144 }
142145
143- query = buf .String ()
144- values = args .mergeSQLNamedArgs (values )
146+ query = ctx .String ()
147+ values = args .mergeSQLNamedArgs (ctx )
145148 return
146149}
147150
148- func (args * Args ) compileNamed (buf * stringBuilder , flavor Flavor , format string , values [] interface {}) ( string , [] interface {}) {
151+ func (args * Args ) compileNamed (ctx * argsCompileContext , format string ) string {
149152 i := 1
150153
151154 for ; i < len (format ) && format [i ] != '}' ; i ++ {
@@ -154,20 +157,20 @@ func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string,
154157
155158 // Invalid $ format. Ignore it.
156159 if i == len (format ) {
157- return format , values
160+ return format
158161 }
159162
160163 name := format [1 :i ]
161164 format = format [i + 1 :]
162165
163166 if p , ok := args .namedArgs [name ]; ok {
164- format , values , _ = args .compileSuccessive (buf , flavor , format , values , p )
167+ format , _ = args .compileSuccessive (ctx , format , p )
165168 }
166169
167- return format , values
170+ return format
168171}
169172
170- func (args * Args ) compileDigits (buf * stringBuilder , flavor Flavor , format string , values [] interface {}, offset int ) (string , [] interface {} , int ) {
173+ func (args * Args ) compileDigits (ctx * argsCompileContext , format string , offset int ) (string , int ) {
171174 i := 1
172175
173176 for ; i < len (format ) && '0' <= format [i ] && format [i ] <= '9' ; i ++ {
@@ -178,91 +181,37 @@ func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string
178181 format = format [i :]
179182
180183 if pointer , err := strconv .Atoi (digits ); err == nil {
181- return args .compileSuccessive (buf , flavor , format , values , pointer )
184+ return args .compileSuccessive (ctx , format , pointer )
182185 }
183186
184- return format , values , offset
187+ return format , offset
185188}
186189
187- func (args * Args ) compileSuccessive (buf * stringBuilder , flavor Flavor , format string , values [] interface {}, offset int ) (string , [] interface {} , int ) {
188- if offset >= len (args .args ) {
189- return format , values , offset
190+ func (args * Args ) compileSuccessive (ctx * argsCompileContext , format string , offset int ) (string , int ) {
191+ if offset >= len (args .argValues ) {
192+ return format , offset
190193 }
191194
192- arg := args .args [offset ]
193- values = args .compileArg (buf , flavor , values , arg )
194-
195- return format , values , offset + 1
196- }
197-
198- func (args * Args ) compileArg (buf * stringBuilder , flavor Flavor , values []interface {}, arg interface {}) []interface {} {
199- switch a := arg .(type ) {
200- case Builder :
201- var s string
202- s , values = a .BuildWithFlavor (flavor , values ... )
203- buf .WriteString (s )
204- case sql.NamedArg :
205- buf .WriteRune ('@' )
206- buf .WriteString (a .Name )
207- case rawArgs :
208- buf .WriteString (a .expr )
209- case listArgs :
210- if a .isTuple {
211- buf .WriteRune ('(' )
212- }
213-
214- if len (a .args ) > 0 {
215- values = args .compileArg (buf , flavor , values , a .args [0 ])
216- }
217-
218- for i := 1 ; i < len (a .args ); i ++ {
219- buf .WriteString (", " )
220- values = args .compileArg (buf , flavor , values , a .args [i ])
221- }
222-
223- if a .isTuple {
224- buf .WriteRune (')' )
225- }
226-
227- default :
228- switch flavor {
229- case MySQL , SQLite , CQL , ClickHouse , Presto , Informix :
230- buf .WriteRune ('?' )
231- case PostgreSQL :
232- fmt .Fprintf (buf , "$%d" , len (values )+ 1 )
233- case SQLServer :
234- fmt .Fprintf (buf , "@p%d" , len (values )+ 1 )
235- case Oracle :
236- fmt .Fprintf (buf , ":%d" , len (values )+ 1 )
237- default :
238- panic (fmt .Errorf ("Args.CompileWithFlavor: invalid flavor %v (%v)" , flavor , int (flavor )))
239- }
240-
241- namedValues := parseNamedArgs (values )
195+ arg := args .argValues [offset ]
196+ ctx .WriteValue (arg )
242197
243- if n := len (namedValues ); n == 0 {
244- values = append (values , arg )
245- } else {
246- index := len (values ) - n
247- values = append (values [:index + 1 ], namedValues ... )
248- values [index ] = arg
249- }
250- }
251-
252- return values
198+ return format , offset + 1
253199}
254200
255- func (args * Args ) mergeSQLNamedArgs (values [] interface {} ) []interface {} {
256- if len (args .sqlNamedArgs ) == 0 {
257- return values
201+ func (args * Args ) mergeSQLNamedArgs (ctx * argsCompileContext ) []interface {} {
202+ if len (args .sqlNamedArgs ) == 0 && len ( ctx . NamedArgs ) == 0 {
203+ return ctx . Values
258204 }
259205
260- namedValues := parseNamedArgs ( values )
261- existingNames := make (map [string ]struct {}, len (namedValues ))
206+ values := ctx . Values
207+ existingNames := make (map [string ]struct {}, len (ctx . NamedArgs ))
262208
263- for _ , v := range namedValues {
264- if a , ok := v .(sql.NamedArg ); ok {
265- existingNames [a .Name ] = struct {}{}
209+ // Add all named args to values.
210+ // Remove duplicated named args in this step.
211+ for _ , arg := range ctx .NamedArgs {
212+ if _ , ok := existingNames [arg .Name ]; ! ok {
213+ existingNames [arg .Name ] = struct {}{}
214+ values = append (values , arg )
266215 }
267216 }
268217
@@ -280,19 +229,21 @@ func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
280229 sort .Ints (ints )
281230
282231 for _ , i := range ints {
283- values = append (values , args .args [i ])
232+ values = append (values , args .argValues [i ])
284233 }
285234
286235 return values
287236}
288237
289- func parseNamedArgs (initialValue []interface {}) (namedValues []interface {}) {
238+ func parseNamedArgs (initialValue []interface {}) (values []interface {}, namedValues []sql. NamedArg ) {
290239 if len (initialValue ) == 0 {
291- return nil
240+ values = initialValue
241+ return
292242 }
293243
294244 // sql.NamedArgs must be placed at the end of the initial value.
295- i := len (initialValue )
245+ size := len (initialValue )
246+ i := size
296247
297248 for ; i > 0 ; i -- {
298249 switch initialValue [i - 1 ].(type ) {
@@ -303,6 +254,97 @@ func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
303254 break
304255 }
305256
306- namedValues = initialValue [i :]
257+ if i == size {
258+ values = initialValue
259+ return
260+ }
261+
262+ values = initialValue [:i ]
263+ namedValues = make ([]sql.NamedArg , 0 , size - i )
264+
265+ for ; i < size ; i ++ {
266+ namedValues = append (namedValues , initialValue [i ].(sql.NamedArg ))
267+ }
268+
307269 return
308270}
271+
272+ type argsCompileContext struct {
273+ * stringBuilder
274+
275+ Flavor Flavor
276+ Values []interface {}
277+ NamedArgs []sql.NamedArg
278+ }
279+
280+ func (ctx * argsCompileContext ) WriteValue (arg interface {}) {
281+ switch a := arg .(type ) {
282+ case Builder :
283+ s , values := a .BuildWithFlavor (ctx .Flavor , ctx .Values ... )
284+ ctx .WriteString (s )
285+
286+ // Add all values to ctx.
287+ // Named args must be located at the end of values.
288+ values , namedArgs := parseNamedArgs (values )
289+ ctx .Values = values
290+ ctx .NamedArgs = append (ctx .NamedArgs , namedArgs ... )
291+
292+ case sql.NamedArg :
293+ ctx .WriteRune ('@' )
294+ ctx .WriteString (a .Name )
295+ ctx .NamedArgs = append (ctx .NamedArgs , a )
296+
297+ case rawArgs :
298+ ctx .WriteString (a .expr )
299+
300+ case listArgs :
301+ if a .isTuple {
302+ ctx .WriteRune ('(' )
303+ }
304+
305+ if len (a .args ) > 0 {
306+ ctx .WriteValue (a .args [0 ])
307+ }
308+
309+ for i := 1 ; i < len (a .args ); i ++ {
310+ ctx .WriteString (", " )
311+ ctx .WriteValue (a .args [i ])
312+ }
313+
314+ if a .isTuple {
315+ ctx .WriteRune (')' )
316+ }
317+
318+ case condBuilder :
319+ a .Builder (ctx )
320+
321+ default :
322+ switch ctx .Flavor {
323+ case MySQL , SQLite , CQL , ClickHouse , Presto , Informix :
324+ ctx .WriteRune ('?' )
325+ case PostgreSQL :
326+ fmt .Fprintf (ctx , "$%d" , len (ctx .Values )+ 1 )
327+ case SQLServer :
328+ fmt .Fprintf (ctx , "@p%d" , len (ctx .Values )+ 1 )
329+ case Oracle :
330+ fmt .Fprintf (ctx , ":%d" , len (ctx .Values )+ 1 )
331+ default :
332+ panic (fmt .Errorf ("Args.CompileWithFlavor: invalid flavor %v (%v)" , ctx .Flavor , int (ctx .Flavor )))
333+ }
334+
335+ ctx .Values = append (ctx .Values , arg )
336+ }
337+ }
338+
339+ func (ctx * argsCompileContext ) WriteValues (values []interface {}, sep string ) {
340+ if len (values ) == 0 {
341+ return
342+ }
343+
344+ ctx .WriteValue (values [0 ])
345+
346+ for _ , v := range values [1 :] {
347+ ctx .WriteString (sep )
348+ ctx .WriteValue (v )
349+ }
350+ }
0 commit comments