Skip to content

Commit 8f54b9d

Browse files
committed
Apply suggestions from review
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent c498545 commit 8f54b9d

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

dali/python/nvidia/dali/experimental/dynamic/_invocation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
self._future: Optional[_Future] = None
9292
self._run_lock = threading.Lock()
9393
if caller_depth is None:
94-
caller_depth = 2 if getattr(self._operator, "_is_reader", False) else 4
94+
caller_depth = 3 if getattr(self._operator, "_is_reader", False) else 4
9595
self._call_stack = (
9696
capture_stack(caller_depth + 1)
9797
if _EvalMode.current().value <= _EvalMode.eager.value

dali/python/nvidia/dali/experimental/dynamic/_op_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def init(self, max_batch_size, name, **kwargs):
203203
op_class.__base__.__init__(self, max_batch_size, name, **kwargs)
204204
if is_reader:
205205
self._tensor_args = {k: v for k, v in tensor_kwargs.items() if v is not None}
206+
if any(isinstance(v, Batch) for v in self._tensor_args.values()):
207+
raise ValueError("Readers cannot be constructed with batch keyword arguments")
206208
if stateful:
207209
self._call_id = 0
208210

@@ -287,11 +289,8 @@ def call(self, *raw_args, batch_size=None, _process_params=True, **raw_kwargs):
287289
if overlap:
288290
raise ValueError(
289291
f"Keyword argument{'s'[:len(overlap)^1]} {sorted(overlap)}"
290-
f" cannot be passed both in the constructor and __call__."
292+
f" cannot be passed in both in the constructor and __call__."
291293
)
292-
for arg in self._tensor_args.values():
293-
if isinstance(arg, Batch):
294-
raise ValueError("Readers cannot be constructed with batch keyword arguments")
295294
raw_kwargs = {**raw_kwargs, **self._tensor_args}
296295

297296
batch_size = _ops._infer_batch_size(batch_size, *raw_args, **raw_kwargs)

dali/python/nvidia/dali/experimental/dynamic/_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def next_epoch(self, batch_size=None, ctx: _eval_context.EvalContext | None = No
668668
else:
669669
return self._samples(ctx)
670670

671-
def _process_tensor_args(self, batch_size):
671+
def _process_tensor_args(self, batch_size: int | None):
672672
"""Converts stored tensor args to Batch/Tensor form for the given batch_size."""
673673
if not self._tensor_args:
674674
return {}
@@ -699,8 +699,10 @@ def _samples(self, ctx: _eval_context.EvalContext | None = None):
699699
self._actual_batch_size = 1
700700
if self._max_batch_size is None:
701701
self._max_batch_size = self._actual_batch_size
702-
self._init_backend(ctx, (), self._process_tensor_args(self._actual_batch_size))
703-
tensor_args = self._process_tensor_args(self._actual_batch_size)
702+
tensor_args = self._process_tensor_args(self._actual_batch_size)
703+
self._init_backend(ctx, (), tensor_args)
704+
else:
705+
tensor_args = self._process_tensor_args(self._actual_batch_size)
704706
meta = self._op_backend.GetReaderMeta()
705707
idx = 0
706708
padded_size = meta["epoch_size_padded"]

0 commit comments

Comments
 (0)