Skip to content

Commit a8b3e1f

Browse files
authored
Calls to AsXXX with column names as single items or single arrays should be the same (#99)
* If you call `AsAwkwardArray(['jetpt'])` or `AsAwkwardArray('jetpt')`, the produced `ast` will be identical. * This helps with caching down the line. Fixes #55
1 parent 3d3b554 commit a8b3e1f

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

func_adl/object_stream.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ def AsPandasDF(
231231
232232
"""
233233
# To get Pandas use the ResultPandasDF function call.
234+
if isinstance(columns, str):
235+
columns = [columns]
236+
234237
return ObjectStream[ReturnedDataPlaceHolder](
235238
function_call("ResultPandasDF", [self._q_ast, as_ast(columns)])
236239
)
@@ -262,6 +265,9 @@ def AsROOTTTree(
262265
dataset. The order of the files back is consistent for different queries on the same
263266
dataset.
264267
"""
268+
if isinstance(columns, str):
269+
columns = [columns]
270+
265271
return ObjectStream[ReturnedDataPlaceHolder](
266272
function_call(
267273
"ResultTTree", [self._q_ast, as_ast(columns), as_ast(treename), as_ast(filename)]
@@ -296,6 +302,9 @@ def AsParquetFiles(
296302
result. The order of the files back is consistent for different queries on the same
297303
dataset.
298304
"""
305+
if isinstance(columns, str):
306+
columns = [columns]
307+
299308
return ObjectStream[ReturnedDataPlaceHolder](
300309
function_call("ResultParquet", [self._q_ast, as_ast(columns), as_ast(filename)])
301310
)
@@ -316,6 +325,9 @@ def AsAwkwardArray(
316325
317326
An `ObjectStream` with the `awkward` array data as its one and only element.
318327
"""
328+
if isinstance(columns, str):
329+
columns = [columns]
330+
319331
return ObjectStream[ReturnedDataPlaceHolder](
320332
function_call("ResultAwkwardArray", [self._q_ast, as_ast(columns)])
321333
)

tests/test_object_stream.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,25 @@ def test_simple_query():
7575
assert isinstance(r, ast.AST)
7676

7777

78+
def test_two_simple_query():
79+
r1 = (
80+
my_event()
81+
.SelectMany("lambda e: e.jets()")
82+
.Select("lambda j: j.pT()")
83+
.AsROOTTTree("junk.root", "analysis", "jetPT")
84+
.value()
85+
)
86+
r2 = (
87+
my_event()
88+
.SelectMany("lambda e: e.jets()")
89+
.Select("lambda j: j.pT()")
90+
.AsROOTTTree("junk.root", "analysis", ["jetPT"])
91+
.value()
92+
)
93+
94+
assert ast.dump(r1) == ast.dump(r2)
95+
96+
7897
def test_with_types():
7998
r1 = my_event_with_type().SelectMany(lambda e: e.Jets("jets"))
8099
r = r1.Select(lambda j: j.eta()).value()
@@ -115,6 +134,25 @@ def test_simple_query_parquet():
115134
assert isinstance(r, ast.AST)
116135

117136

137+
def test_two_simple_query_parquet():
138+
r1 = (
139+
my_event()
140+
.SelectMany("lambda e: e.jets()")
141+
.Select("lambda j: j.pT()")
142+
.AsParquetFiles("junk.root", "jetPT")
143+
.value()
144+
)
145+
r2 = (
146+
my_event()
147+
.SelectMany("lambda e: e.jets()")
148+
.Select("lambda j: j.pT()")
149+
.AsParquetFiles("junk.root", ["jetPT"])
150+
.value()
151+
)
152+
153+
assert ast.dump(r1) == ast.dump(r2)
154+
155+
118156
def test_simple_query_panda():
119157
r = (
120158
my_event()
@@ -126,6 +164,24 @@ def test_simple_query_panda():
126164
assert isinstance(r, ast.AST)
127165

128166

167+
def test_two_imple_query_panda():
168+
r1 = (
169+
my_event()
170+
.SelectMany("lambda e: e.jets()")
171+
.Select("lambda j: j.pT()")
172+
.AsPandasDF(["analysis"])
173+
.value()
174+
)
175+
r2 = (
176+
my_event()
177+
.SelectMany("lambda e: e.jets()")
178+
.Select("lambda j: j.pT()")
179+
.AsPandasDF(["analysis"])
180+
.value()
181+
)
182+
assert ast.dump(r1) == ast.dump(r2)
183+
184+
129185
def test_simple_query_awkward():
130186
r = (
131187
my_event()
@@ -137,6 +193,25 @@ def test_simple_query_awkward():
137193
assert isinstance(r, ast.AST)
138194

139195

196+
def test_two_similar_query_awkward():
197+
r1 = (
198+
my_event()
199+
.SelectMany("lambda e: e.jets()")
200+
.Select("lambda j: j.pT()")
201+
.AsAwkwardArray(["analysis"])
202+
.value()
203+
)
204+
r2 = (
205+
my_event()
206+
.SelectMany("lambda e: e.jets()")
207+
.Select("lambda j: j.pT()")
208+
.AsAwkwardArray("analysis")
209+
.value()
210+
)
211+
212+
assert ast.dump(r1) == ast.dump(r2)
213+
214+
140215
def test_metadata():
141216
r = (
142217
my_event()

0 commit comments

Comments
 (0)