Skip to content

Commit aa95407

Browse files
authored
Merge pull request #142 from TidierOrg/case_whenfix
fix `case_when` syntax
2 parents 1c53865 + b901de7 commit aa95407

File tree

5 files changed

+51
-59
lines changed

5 files changed

+51
-59
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
## v.8.7 - 2025-07-07
33
- AWS Athena backend bug fixes
44
- add `temp` option to `@create_table`, default is `true`
5+
- adjust `case_when` syntax to match TidierData
56

67
## v.8.6 - 2025-05-05
78
- add `@pivot_longer`

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ TidierDB.jl currently supports the following top-level macros:
4040
|----------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
4141
| **Data Manipulation** | `@arrange`, `@group_by`, `@filter`, `@select`, `@mutate` (supports `across`), `@summarize`/`@summarise` (supports `across`), `@distinct`, `@relocate`, `@transmute` |
4242
| **Joining/Setting** | `@left_join`, `@right_join`, `@inner_join`, `@anti_join`, `@full_join`, `@semi_join`, `@union`, `@union_all`, `@intersect`, `@setdiff` |
43-
| **Slice and Order** | `@slice_min`, `@slice_max`, `@slice_sample`, `@order`, `@window_order`, `@window_frame` |
44-
| **Utility** | `@show_query`, `@collect`, `@head`, `@count`, `@drop_missing`, `show_tables`, `@create_view` , `drop_view` |
43+
| **Slice and Order** | `@slice_min`, `@slice_max`, `@slice_sample`, `@arrange`, `@window_order`, `@window_frame` |
44+
| **Utility** | `@show_query`, `@collect`, `@head`, `@count`, `@drop_missing`, `show_tables`, `@create_table`, `@create_view`, `drop_view` |
4545
| **Helper Functions** | `across`, `desc`, `if_else`, `case_when`, `n`, `starts_with`, `ends_with`, `contains`, `as_float`, `as_integer`, `as_string`, `is_missing`, `missing_if`, `replace_missing` |
4646
| **TidierStrings.jl Functions** | `str_detect`, `str_replace`, `str_replace_all`, `str_remove_all`, `str_remove` |
4747
| **TidierDates.jl Functions** | `year`, `month`, `day`, `hour`, `min`, `second`, `floor_date`, `difftime`, `mdy`, `ymd`, `dmy` |
@@ -80,8 +80,8 @@ mtcars = DB.dt(db, path_or_name);
8080
DB.@mutate(mpg_squared = mpg^2,
8181
mpg_rounded = round(mpg),
8282
mpg_efficiency = case_when(
83-
mpg >= cyl^2 , "efficient",
84-
mpg < 15.2 , "inefficient",
83+
mpg >= cyl^2 => "efficient",
84+
mpg < 15.2 => "inefficient",
8585
"moderate"))
8686
DB.@filter(mpg_efficiency in ("moderate", "efficient"))
8787
DB.@arrange(desc(mpg_rounded))

docs/examples/UserGuide/key_differences.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,3 @@ end
5050
_by = groups)
5151
@collect
5252
end
53-
54-
# ## Differences in `case_when()`
55-
56-
# In TidierDB, after the clause is completed, the result for the new column should is separated by a comma `,`
57-
# in contrast to TidierData.jl, where the result for the new column is separated by a `=>` .
58-
59-
@chain dfv begin
60-
@mutate(new_col = case_when(percent > .5, "Pass", # in TidierData, percent > .5 => "Pass",
61-
percent <= .5, "Try Again", # percent <= .5 => "Try Again"
62-
true, "middle"))
63-
@collect
64-
end

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ We can replace `DB.collect()` with `DB.@show_query` to reveal the underlying SQL
144144
DB.@mutate(mpg_squared = mpg^2,
145145
mpg_rounded = round(mpg),
146146
mpg_efficiency = case_when(
147-
mpg >= cyl^2 , "efficient",
148-
mpg < 15.2 , "inefficient",
147+
mpg >= cyl^2 => "efficient",
148+
mpg < 15.2 => "inefficient",
149149
"moderate"))
150150
DB.@filter(mpg_efficiency in ("moderate", "efficient"))
151151
DB.@arrange(desc(mpg_rounded))

src/db_parsing.jl

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -238,58 +238,61 @@ end
238238

239239
function parse_case_when(expr)
240240
MacroTools.postwalk(expr) do x
241-
# Ensure we're dealing with an Expr object
242-
if isa(x, Expr)
243-
# Check for a case_when expression
244-
if x.head == :call && x.args[1] == :case_when
245-
# Initialize components for building a SQL CASE expression
246-
sql_case_parts = ["CASE"]
247-
248-
# Iterate through the arguments of the case_when call, skipping the function name
249-
for i in 2:2:length(x.args)-1
250-
# Ensure we're only adding valid expressions
251-
cond = x.args[i]
252-
result = x.args[i + 1]
253-
254-
# Handle `missing` by converting it to `NULL`
255-
result_formatted = if result === :missing
256-
"NULL"
257-
elseif isa(result, String)
258-
"'$result'"
259-
else
260-
result
261-
end
241+
if isa(x, Expr) && x.head == :call && x.args[1] == :case_when
242+
sql_parts = ["CASE"]
243+
args = x.args[2:end]
244+
245+
expanded = Any[]
246+
default = nothing
247+
248+
249+
i = 1
250+
while i length(args)
251+
arg = args[i]
252+
253+
if isa(arg, Expr) && arg.head == :call && arg.args[1] == :(=>)
254+
push!(expanded, arg.args[2], arg.args[3])
255+
i += 1
256+
257+
elseif isa(arg, Pair)
258+
push!(expanded, arg.first, arg.second)
259+
i += 1
262260

263-
# Append the WHEN-THEN part to the SQL CASE expression
264-
push!(sql_case_parts, "WHEN $(cond) THEN $(result_formatted)")
265-
end
266-
267-
# Handle the default case, the last argument
268-
default_result = x.args[end]
269-
default_result_formatted = if default_result === :missing
270-
"NULL"
271-
elseif isa(default_result, String)
272-
"'$default_result'"
273261
else
274-
default_result
262+
if i == length(args)
263+
default = arg
264+
i += 1
265+
else
266+
push!(expanded, arg, args[i + 1])
267+
i += 2
268+
end
275269
end
270+
end
276271

277-
# Append the ELSE part and the END
278-
push!(sql_case_parts, "ELSE $(default_result_formatted) END")
279-
280-
# Combine into a complete SQL CASE statement
281-
sql_case = join(sql_case_parts, " ")
282-
283-
# Directly return the SQL CASE statement string
284-
return sql_case
272+
for j in 1:2:length(expanded)
273+
cond = expanded[j]
274+
result = expanded[j + 1]
275+
276+
res_sql = result === :missing ? "NULL" :
277+
isa(result, String) ? "'$result'" : result
278+
push!(sql_parts, "WHEN $(cond) THEN $(res_sql)")
279+
end
280+
281+
if default !== nothing
282+
def_sql = default === :missing ? "NULL" :
283+
isa(default, String) ? "'$default'" : default
284+
push!(sql_parts, "ELSE $(def_sql)")
285285
end
286+
287+
push!(sql_parts, "END")
288+
return join(sql_parts, " ")
286289
end
287-
# Return the unmodified object if it's not an Expr or not a case_when call
288290
return x
289291
end
290292
end
291293

292294

295+
293296
#this fxn is not being tested, bc its only in backends. - i might be able to get rid of it entirely as well
294297
# COV_EXCL_START
295298
function parse_char_matching(expr)

0 commit comments

Comments
 (0)