Skip to content

Commit 2ca19a4

Browse files
authored
use clone appropriately in setting min limit (#225)
1 parent 7bb9bae commit 2ca19a4

File tree

3 files changed

+34
-14
lines changed

3 files changed

+34
-14
lines changed

src/datachain/lib/dc.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,6 @@ class DataChain(DatasetQuery):
193193
```
194194
"""
195195

196-
max_row_count: Optional[int] = None
197-
198196
DEFAULT_FILE_RECORD: ClassVar[dict] = {
199197
"source": "",
200198
"name": "",
@@ -1603,18 +1601,7 @@ def filter(self, *args) -> "Self":
16031601
@detach
16041602
def limit(self, n: int) -> "Self":
16051603
"""Return the first n rows of the chain."""
1606-
n = max(n, 0)
1607-
1608-
if self.max_row_count is None:
1609-
self.max_row_count = n
1610-
return super().limit(n)
1611-
1612-
limit = min(n, self.max_row_count)
1613-
if limit == self.max_row_count:
1614-
return self
1615-
1616-
self.max_row_count = limit
1617-
return super().limit(self.max_row_count)
1604+
return super().limit(n)
16181605

16191606
@detach
16201607
def offset(self, offset: int) -> "Self":

src/datachain/query/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,9 @@ def order_by(self, *args) -> "Self":
13831383
@detach
13841384
def limit(self, n: int) -> "Self":
13851385
query = self.clone(new_table=False)
1386+
for step in query.steps:
1387+
if isinstance(step, SQLLimit) and step.n < n:
1388+
return query
13861389
query.steps.append(SQLLimit(n))
13871390
return query
13881391

tests/unit/lib/test_datachain.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,3 +1228,33 @@ class Nested(BaseModel):
12281228
traces_double=[[{"x": 0.5, "y": 0.5}], [{"x": 0.5, "y": 0.5}]],
12291229
)
12301230
]
1231+
1232+
1233+
def test_min_limit():
1234+
dc = DataChain.from_values(a=[1, 2, 3, 4, 5])
1235+
assert dc.count() == 5
1236+
assert dc.limit(4).count() == 4
1237+
assert dc.count() == 5
1238+
assert dc.limit(1).count() == 1
1239+
assert dc.count() == 5
1240+
assert dc.limit(2).limit(3).count() == 2
1241+
assert dc.count() == 5
1242+
assert dc.limit(3).limit(2).count() == 2
1243+
assert dc.count() == 5
1244+
1245+
1246+
def test_show_limit():
1247+
dc = DataChain.from_values(a=[1, 2, 3, 4, 5])
1248+
assert dc.count() == 5
1249+
assert dc.limit(4).count() == 4
1250+
dc.show(1)
1251+
assert dc.count() == 5
1252+
assert dc.limit(1).count() == 1
1253+
dc.show(1)
1254+
assert dc.count() == 5
1255+
assert dc.limit(2).limit(3).count() == 2
1256+
dc.show(1)
1257+
assert dc.count() == 5
1258+
assert dc.limit(3).limit(2).count() == 2
1259+
dc.show(1)
1260+
assert dc.count() == 5

0 commit comments

Comments
 (0)