Skip to content

Commit 76b5aef

Browse files
authored
More func tests (#1119)
* Add func tests for sql functions * Update and improve docstrings for consistency
1 parent 896c8a6 commit 76b5aef

File tree

18 files changed

+1636
-489
lines changed

18 files changed

+1636
-489
lines changed

src/datachain/func/aggregate.py

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,89 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22

33
from sqlalchemy import func as sa_func
44

5+
from datachain.query.schema import Column
56
from datachain.sql.functions import aggregate
67

78
from .func import Func
89

910

10-
def count(col: Optional[str] = None) -> Func:
11+
def count(col: Optional[Union[str, Column]] = None) -> Func:
1112
"""
12-
Returns the COUNT aggregate SQL function for the given column name.
13+
Returns a COUNT aggregate SQL function for the specified column.
1314
14-
The COUNT function returns the number of rows in a table.
15+
The COUNT function returns the number of rows or, when a column is given,
16+
the number of rows in which that column is not NULL.
1517
1618
Args:
17-
col (str, optional): The name of the column for which to count rows.
18-
If not provided, it defaults to counting all rows.
19+
col (str | Column, optional): The column to count.
20+
If omitted, counts all rows.
21+
The column can be specified as a string or a `Column` object.
1922
2023
Returns:
21-
Func: A Func object that represents the COUNT aggregate function.
24+
Func: A `Func` object representing the COUNT aggregate function.
2225
2326
Example:
2427
```py
2528
dc.group_by(
26-
count=func.count(),
29+
count1=func.count(),
30+
count2=func.count("signal.id"),
31+
count3=func.count(dc.C("signal.category")),
2732
partition_by="signal.category",
2833
)
2934
```
3035
3136
Notes:
32-
- Result column will always be of type int.
37+
- The result column will always have an integer type.
3338
"""
3439
return Func(
35-
"count", inner=sa_func.count, cols=[col] if col else None, result_type=int
40+
"count",
41+
inner=sa_func.count,
42+
cols=[col] if col is not None else None,
43+
result_type=int,
3644
)
3745

3846

39-
def sum(col: str) -> Func:
47+
def sum(col: Union[str, Column]) -> Func:
4048
"""
41-
Returns the SUM aggregate SQL function for the given column name.
49+
Returns the SUM aggregate SQL function for the specified column.
4250
4351
The SUM function returns the total sum of a numeric column in a table.
4452
It sums up all the values for the specified column.
4553
4654
Args:
47-
col (str): The name of the column for which to calculate the sum.
55+
col (str | Column): The name of the column for which to calculate the sum.
56+
The column can be specified as a string or a `Column` object.
4857
4958
Returns:
50-
Func: A Func object that represents the SUM aggregate function.
59+
Func: A `Func` object that represents the SUM aggregate function.
5160
5261
Example:
5362
```py
5463
dc.group_by(
5564
files_size=func.sum("file.size"),
65+
total_size=func.sum(dc.C("size")),
5666
partition_by="signal.category",
5767
)
5868
```
5969
6070
Notes:
6171
- The `sum` function should be used on numeric columns.
62-
- Result column type will be the same as the input column type.
72+
- The result column type will be the same as the input column type.
6373
"""
6474
return Func("sum", inner=sa_func.sum, cols=[col])
6575

6676

67-
def avg(col: str) -> Func:
77+
def avg(col: Union[str, Column]) -> Func:
6878
"""
69-
Returns the AVG aggregate SQL function for the given column name.
79+
Returns the AVG aggregate SQL function for the specified column.
7080
7181
The AVG function returns the average of a numeric column in a table.
7282
It calculates the mean of all values in the specified column.
7383
7484
Args:
75-
col (str): The name of the column for which to calculate the average.
85+
col (str | Column): The name of the column for which to calculate the average.
86+
Column can be specified as a string or a `Column` object.
7687
7788
Returns:
7889
Func: A Func object that represents the AVG aggregate function.
@@ -81,26 +92,28 @@ def avg(col: str) -> Func:
8192
```py
8293
dc.group_by(
8394
avg_file_size=func.avg("file.size"),
95+
avg_signal_value=func.avg(dc.C("signal.value")),
8496
partition_by="signal.category",
8597
)
8698
```
8799
88100
Notes:
89101
- The `avg` function should be used on numeric columns.
90-
- Result column will always be of type float.
102+
- The result column will always be of type float.
91103
"""
92104
return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
93105

94106

95-
def min(col: str) -> Func:
107+
def min(col: Union[str, Column]) -> Func:
96108
"""
97-
Returns the MIN aggregate SQL function for the given column name.
109+
Returns the MIN aggregate SQL function for the specified column.
98110
99111
The MIN function returns the smallest value in the specified column.
100112
It can be used on both numeric and non-numeric columns to find the minimum value.
101113
102114
Args:
103-
col (str): The name of the column for which to find the minimum value.
115+
col (str | Column): The name of the column for which to find the minimum value.
116+
Column can be specified as a string or a `Column` object.
104117
105118
Returns:
106119
Func: A Func object that represents the MIN aggregate function.
@@ -109,26 +122,28 @@ def min(col: str) -> Func:
109122
```py
110123
dc.group_by(
111124
smallest_file=func.min("file.size"),
125+
min_signal=func.min(dc.C("signal")),
112126
partition_by="signal.category",
113127
)
114128
```
115129
116130
Notes:
117131
- The `min` function can be used with numeric, date, and string columns.
118-
- Result column will have the same type as the input column.
132+
- The result column will have the same type as the input column.
119133
"""
120134
return Func("min", inner=sa_func.min, cols=[col])
121135

122136

123-
def max(col: str) -> Func:
137+
def max(col: Union[str, Column]) -> Func:
124138
"""
125139
Returns the MAX aggregate SQL function for the specified column.
126140
127141
The MAX function returns the largest value in the specified column.
128142
It can be used on both numeric and non-numeric columns to find the maximum value.
129143
130144
Args:
131-
col (str): The name of the column for which to find the maximum value.
145+
col (str | Column): The name of the column for which to find the maximum value.
146+
Column can be specified as a string or a `Column` object.
132147
133148
Returns:
134149
Func: A Func object that represents the MAX aggregate function.
@@ -137,18 +152,19 @@ def max(col: str) -> Func:
137152
```py
138153
dc.group_by(
139154
largest_file=func.max("file.size"),
155+
max_signal=func.max(dc.C("signal")),
140156
partition_by="signal.category",
141157
)
142158
```
143159
144160
Notes:
145161
- The `max` function can be used with numeric, date, and string columns.
146-
- Result column will have the same type as the input column.
162+
- The result column will have the same type as the input column.
147163
"""
148164
return Func("max", inner=sa_func.max, cols=[col])
149165

150166

151-
def any_value(col: str) -> Func:
167+
def any_value(col: Union[str, Column]) -> Func:
152168
"""
153169
Returns the ANY_VALUE aggregate SQL function for the given column name.
154170
@@ -157,7 +173,9 @@ def any_value(col: str) -> Func:
157173
as long as it comes from one of the rows in the group.
158174
159175
Args:
160-
col (str): The name of the column from which to return an arbitrary value.
176+
col (str | Column): The name of the column from which to return
177+
an arbitrary value.
178+
Column can be specified as a string or a `Column` object.
161179
162180
Returns:
163181
Func: A Func object that represents the ANY_VALUE aggregate function.
@@ -166,20 +184,21 @@ def any_value(col: str) -> Func:
166184
```py
167185
dc.group_by(
168186
file_example=func.any_value("file.path"),
187+
signal_example=func.any_value(dc.C("signal.value")),
169188
partition_by="signal.category",
170189
)
171190
```
172191
173192
Notes:
174193
- The `any_value` function can be used with any type of column.
175-
- Result column will have the same type as the input column.
194+
- The result column will have the same type as the input column.
176195
- The result of `any_value` is non-deterministic,
177196
meaning it may return different values for different executions.
178197
"""
179198
return Func("any_value", inner=aggregate.any_value, cols=[col])
180199

181200

182-
def collect(col: str) -> Func:
201+
def collect(col: Union[str, Column]) -> Func:
183202
"""
184203
Returns the COLLECT aggregate SQL function for the given column name.
185204
@@ -188,7 +207,8 @@ def collect(col: str) -> Func:
188207
into a collection, often for further processing or aggregation.
189208
190209
Args:
191-
col (str): The name of the column from which to collect values.
210+
col (str | Column): The name of the column from which to collect values.
211+
Column can be specified as a string or a `Column` object.
192212
193213
Returns:
194214
Func: A Func object that represents the COLLECT aggregate function.
@@ -197,18 +217,19 @@ def collect(col: str) -> Func:
197217
```py
198218
dc.group_by(
199219
signals=func.collect("signal"),
220+
file_paths=func.collect(dc.C("file.path")),
200221
partition_by="signal.category",
201222
)
202223
```
203224
204225
Notes:
205226
- The `collect` function can be used with numeric and string columns.
206-
- Result column will have an array type.
227+
- The result column will have an array type.
207228
"""
208229
return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
209230

210231

211-
def concat(col: str, separator="") -> Func:
232+
def concat(col: Union[str, Column], separator="") -> Func:
212233
"""
213234
Returns the CONCAT aggregate SQL function for the given column name.
214235
@@ -217,9 +238,10 @@ def concat(col: str, separator="") -> Func:
217238
into a single combined value.
218239
219240
Args:
220-
col (str): The name of the column from which to concatenate values.
241+
col (str | Column): The name of the column from which to concatenate values.
242+
Column can be specified as a string or a `Column` object.
221243
separator (str, optional): The separator to use between concatenated values.
222-
Defaults to an empty string.
244+
Defaults to an empty string.
223245
224246
Returns:
225247
Func: A Func object that represents the CONCAT aggregate function.
@@ -228,13 +250,14 @@ def concat(col: str, separator="") -> Func:
228250
```py
229251
dc.group_by(
230252
files=func.concat("file.path", separator=", "),
253+
signals=func.concat(dc.C("signal.name"), separator=" | "),
231254
partition_by="signal.category",
232255
)
233256
```
234257
235258
Notes:
236259
- The `concat` function can be used with string columns.
237-
- Result column will have a string type.
260+
- The result column will have a string type.
238261
"""
239262

240263
def inner(arg):
@@ -325,7 +348,7 @@ def dense_rank() -> Func:
325348
return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True)
326349

327350

328-
def first(col: str) -> Func:
351+
def first(col: Union[str, Column]) -> Func:
329352
"""
330353
Returns the FIRST_VALUE window function for SQL queries.
331354
@@ -334,7 +357,9 @@ def first(col: str) -> Func:
334357
and can be useful for retrieving the leading value in a group of rows.
335358
336359
Args:
337-
col (str): The name of the column from which to retrieve the first value.
360+
col (str | Column): The name of the column from which to retrieve
361+
the first value.
362+
Column can be specified as a string or a `Column` object.
338363
339364
Returns:
340365
Func: A Func object that represents the FIRST_VALUE window function.
@@ -344,6 +369,7 @@ def first(col: str) -> Func:
344369
window = func.window(partition_by="signal.category", order_by="created_at")
345370
dc.mutate(
346371
first_file=func.first("file.path").over(window),
372+
first_signal=func.first(dc.C("signal.value")).over(window),
347373
)
348374
```
349375

0 commit comments

Comments
 (0)