1515from __future__ import annotations
1616
1717import functools
18+ import typing
1819
1920import sqlglot .expressions as sge
2021
2930
@register_unary_op(ops.capitalize_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``capitalize_op`` into an ``INITCAP`` expression.

    An empty string literal is passed as the second (``expression``)
    argument of ``Initcap``.
    """
    # NOTE(review): presumably the empty delimiter set makes INITCAP treat the
    # whole value as a single word (only the first character upper-cased) --
    # confirm against the target dialect's INITCAP semantics.
    no_delimiters = sge.convert("")
    return sge.Initcap(this=expr.expr, expression=no_delimiters)
3334
3435
3536@register_unary_op (ops .StrContainsOp , pass_op = True )
@@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:
4445
@register_unary_op(ops.StrExtractOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
    """Compile ``StrExtractOp``: extract capture group ``op.n`` of ``op.pat``.

    The result is built as
    ``IF(REGEXP_CONTAINS(x, pat), REGEXP_REPLACE(x, <wrapped pat>, <group ref>), NULL)``:

    * ``op.n == 0`` -- the pattern is wrapped in an extra group
      (``.*?(`` + pat + ``).*``) and the replacement is ``\\1``, i.e. the
      whole match.
    * ``op.n != 0`` -- the pattern is surrounded by ``.*?`` / ``.*`` and the
      replacement references group ``op.n`` of the user's pattern.

    Rows where the pattern does not match yield NULL.
    """
    # Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
    # capturing group.
    if op.n != 0:
        wrapped_pat = sge.func(
            "CONCAT", sge.convert(".*?"), sge.convert(op.pat), sge.convert(".*")
        )
        # Reference the requested group; a hard-coded "\1" would return the
        # first group for every non-zero ``n``.
        group_ref = "\\" + str(op.n)
    else:
        wrapped_pat = sge.func(
            "CONCAT", sge.convert(".*?("), sge.convert(op.pat), sge.convert(").*")
        )
        group_ref = r"\1"

    rex_replace = sge.func(
        "REGEXP_REPLACE", expr.expr, wrapped_pat, sge.convert(group_ref)
    )
    rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
    return sge.If(this=rex_contains, true=rex_replace, false=sge.null())
5059
5160
5261@register_unary_op (ops .StrFindOp , pass_op = True )
@@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:
7584
@register_unary_op(ops.StrLstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression:
    """Compile ``StrLstripOp`` into an ``LTRIM`` function call.

    The second argument of ``LTRIM`` is the literal ``op.to_strip``.
    """
    strip_chars = sge.convert(op.to_strip)
    return sge.func("LTRIM", expr.expr, strip_chars)
88+
89+
@register_unary_op(ops.StrRstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
    """Compile ``StrRstripOp`` into an ``RTRIM`` function call.

    The second argument of ``RTRIM`` is the literal ``op.to_strip``.
    """
    strip_chars = sge.convert(op.to_strip)
    return sge.func("RTRIM", expr.expr, strip_chars)
7993
8094
@register_unary_op(ops.StrPadOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression:
    """Compile ``StrPadOp`` into ``LPAD`` / ``RPAD`` calls.

    The target width is ``GREATEST(LENGTH(x), op.length)``, so it is never
    smaller than the input's own length. ``op.side`` selects the strategy:

    * ``"left"``  -- a single ``LPAD``.
    * ``"right"`` -- a single ``RPAD``.
    * otherwise (both sides) -- ``LPAD`` up to the input length plus
      ``FLOOR((target - length) / 2)``, then ``RPAD`` up to the target width.
    """
    input_length = sge.Length(this=expr.expr)
    fill = sge.convert(op.fillchar)
    target_length = sge.func("GREATEST", input_length, sge.convert(op.length))

    if op.side == "left":
        return sge.func("LPAD", expr.expr, target_length, fill)
    if op.side == "right":
        return sge.func("RPAD", expr.expr, target_length, fill)

    # side == both: the left side receives the floored half of the padding.
    missing = sge.Sub(this=target_length, expression=input_length)
    half_missing = sge.func("SAFE_DIVIDE", missing, sge.convert(2))
    left_width = (
        sge.Cast(this=sge.Floor(this=half_missing), to="INT64") + input_length
    )
    left_padded = sge.func("LPAD", expr.expr, left_width, fill)
    return sge.func("RPAD", left_padded, target_length, fill)
120125
121126
@@ -148,12 +153,15 @@ def _(expr: TypedExpr) -> sge.Expression:
148153
@register_unary_op(ops.isdecimal_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``isdecimal_op`` into a regular-expression match.

    The value matches when it consists of one or more characters from the
    Unicode ``Nd`` (decimal digit) category and nothing else.
    """
    # The quantifier must directly follow the group: a space between ")" and
    # "+" would require literal spaces after a single digit.
    decimal_pattern = r"^(\p{Nd})+$"
    return sge.RegexpLike(this=expr.expr, expression=sge.convert(decimal_pattern))
152157
153158
@register_unary_op(ops.isdigit_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``isdigit_op`` into a regular-expression match.

    The value matches when every character is either in the Unicode ``Nd``
    category or one of the explicitly listed code points: U+00B9, U+00B2,
    U+00B3, U+2070, U+2074-U+2079 and U+2080-U+2089.
    """
    digit_pattern = sge.convert(
        r"^[\p{Nd}\x{00B9}\x{00B2}\x{00B3}\x{2070}\x{2074}-\x{2079}\x{2080}-\x{2089}]+$"
    )
    return sge.RegexpLike(this=expr.expr, expression=digit_pattern)
157165
158166
159167@register_unary_op (ops .islower_op )
@@ -224,11 +232,6 @@ def _(expr: TypedExpr) -> sge.Expression:
224232 return sge .func ("REVERSE" , expr .expr )
225233
226234
227- @register_unary_op (ops .StrRstripOp , pass_op = True )
228- def _ (expr : TypedExpr , op : ops .StrRstripOp ) -> sge .Expression :
229- return sge .Trim (this = expr .expr , expression = sge .convert (op .to_strip ), side = "RIGHT" )
230-
231-
232235@register_unary_op (ops .StartsWithOp , pass_op = True )
233236def _ (expr : TypedExpr , op : ops .StartsWithOp ) -> sge .Expression :
234237 if not op .pat :
@@ -253,27 +256,12 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253256
@register_unary_op(ops.StrGetOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
    """Compile ``StrGetOp`` by delegating to ``string_index`` with ``op.i``."""
    position = op.i
    return string_index(expr, position)
261260
262261
@register_unary_op(ops.StrSliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
    """Compile ``StrSliceOp`` by delegating to ``string_slice``.

    ``op.start`` and ``op.end`` are forwarded unchanged as the slice bounds.
    """
    lower, upper = op.start, op.end
    return string_slice(expr, lower, upper)
277265
278266
279267@register_unary_op (ops .upper_op )
@@ -314,3 +302,79 @@ def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
314302 ],
315303 default = sge .func ("LPAD" , expr .expr , length_expr , sge .convert ("0" )),
316304 )
305+
306+
def string_index(expr: TypedExpr, index: int) -> sge.Expression:
    """Build an expression for the single character at ``index``.

    Args:
        expr: The string expression to index into.
        index: Zero-based character index. Negative values count from the
            end of the string (``-1`` is the last character).

    Returns:
        An ``IF`` expression that yields the one-character substring when
        the index is inside the string and NULL otherwise.
    """
    if index >= 0:
        # SUBSTRING positions are 1-based.
        sub_str = sge.Substring(
            this=expr.expr,
            start=sge.convert(index + 1),
            length=sge.convert(1),
        )
        # Reading past the end produces an empty string; map that to NULL.
        in_range: sge.Expression = sge.NEQ(this=sub_str, expression=sge.convert(""))
    else:
        # A negative position counts from the end, so it is passed through
        # unchanged; ``index + 1`` would be 0 or an off-by-one position.
        sub_str = sge.Substring(
            this=expr.expr,
            start=sge.convert(index),
            length=sge.convert(1),
        )
        # NOTE(review): BigQuery restarts at position 1 when the negative
        # position falls before the start of the string, so the range check
        # has to be made on the length explicitly.
        in_range = sge.GTE(
            this=sge.Length(this=expr.expr), expression=sge.convert(-index)
        )

    return sge.If(this=in_range, true=sub_str, false=sge.Null())
318+
319+
def string_slice(
    expr: TypedExpr, op_start: typing.Optional[int], op_end: typing.Optional[int]
) -> sge.Expression:
    """Build a ``SUBSTRING`` expression for the slice ``[op_start:op_end]``.

    Args:
        expr: The string expression to slice.
        op_start: Zero-based inclusive start; ``None`` is treated as 0.
            Negative values count from the end of the string.
        op_end: Zero-based exclusive end; ``None`` means "to the end of the
            string". Negative values count from the end of the string.

    Returns:
        A ``Substring`` expression with a 1-based start position and, unless
        ``op_end`` is ``None``, a length that is clamped to be non-negative so
        that inverted bounds (for example ``[3:1]``) produce an empty string
        instead of a negative length.
    """
    column_length = sge.Length(this=expr.expr)
    start = 0 if op_start is None else op_start

    def at_least(floor: int, value: sge.Expression) -> sge.Expression:
        # GREATEST(floor, value): clamp ``value`` from below.
        return sge.Greatest(expressions=[sge.convert(floor), value])

    # SUBSTRING is 1-based; a negative start is passed through unchanged when
    # no length is needed.
    start_expr: sge.Expression = (
        sge.convert(start) if start < 0 else sge.convert(start + 1)
    )
    length_expr: typing.Optional[sge.Expression]

    if op_end is None:
        length_expr = None
    elif start < 0:
        # Resolve the negative start to an absolute 1-based position so the
        # length can be computed relative to it.
        start_expr = at_least(1, column_length + sge.convert(start + 1))
        # Number of characters in front of the (clamped) start position.
        skipped = at_least(0, column_length + sge.convert(start))
        if op_end < 0:
            stop: sge.Expression = at_least(0, column_length + sge.convert(op_end))
        else:
            stop = sge.convert(op_end)
        # Clamp: the end may lie before the start.
        length_expr = at_least(0, stop - skipped)
    elif op_end < 0:
        length_expr = at_least(0, column_length + sge.convert(op_end - start))
    else:
        # Both bounds are non-negative constants; clamp in Python.
        length_expr = sge.convert(max(0, op_end - start))

    return sge.Substring(
        this=expr.expr,
        start=start_expr,
        length=length_expr,
    )
0 commit comments