Skip to content

Commit 21ffb4e

Browse files
authored
feat(null-propagation): required arg is Null->return Null for custom fn (#839)
1 parent 1a31b8e commit 21ffb4e

File tree

2 files changed

+106
-39
lines changed

2 files changed

+106
-39
lines changed

python/cocoindex/op.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def __call__(
114114
) -> tuple[dict[str, Any], Executor]:
115115
spec = _load_spec_from_engine(self._spec_cls, spec)
116116
executor = self._executor_cls(spec)
117-
result_type = executor.analyze(*args, **kwargs)
118-
return (encode_enriched_type(result_type), executor)
117+
result_type = executor.analyze_schema(*args, **kwargs)
118+
return (result_type, executor)
119119

120120

121121
_gpu_dispatch_lock = asyncio.Lock()
@@ -156,6 +156,12 @@ def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
156156
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
157157

158158

159+
@dataclasses.dataclass
160+
class _ArgInfo:
161+
decoder: Callable[[Any], Any]
162+
is_required: bool
163+
164+
159165
def _register_op_factory(
160166
category: OpCategory,
161167
expected_args: list[tuple[str, inspect.Parameter]],
@@ -176,37 +182,54 @@ def behavior_version(self) -> int | None:
176182
return op_args.behavior_version
177183

178184
class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
179-
_args_decoders: list[Callable[[Any], Any]]
180-
_kwargs_decoders: dict[str, Callable[[Any], Any]]
185+
_args_info: list[_ArgInfo]
186+
_kwargs_info: dict[str, _ArgInfo]
181187
_acall: Callable[..., Awaitable[Any]]
182188

183189
def __init__(self, spec: Any) -> None:
    """
    Initialize the wrapped executor and remember its spec.

    The base class's `__call__` is wrapped once via `_to_async_call`, so
    that it can later be awaited through `self._acall`.
    """
    super().__init__()
    self.spec = spec
    base_call = super().__call__
    self._acall = _to_async_call(base_call)
187193

188-
def analyze(
194+
def analyze_schema(
189195
self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
190196
) -> Any:
191197
"""
192198
Analyze the spec and arguments. In this phase, argument types should be validated.
193199
It should return the expected result type for the current op.
194200
"""
195-
self._args_decoders = []
196-
self._kwargs_decoders = {}
201+
self._args_info = []
202+
self._kwargs_info = {}
197203
attributes = []
198-
199-
def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
204+
potentially_missing_required_arg = False
205+
206+
def process_arg(
207+
arg_name: str,
208+
arg_param: inspect.Parameter,
209+
actual_arg: _engine.OpArgSchema,
210+
) -> _ArgInfo:
211+
nonlocal potentially_missing_required_arg
200212
if op_args.arg_relationship is not None:
201213
related_attr, related_arg_name = op_args.arg_relationship
202214
if related_arg_name == arg_name:
203215
attributes.append(
204-
TypeAttr(related_attr.value, arg.analyzed_value)
216+
TypeAttr(related_attr.value, actual_arg.analyzed_value)
205217
)
218+
type_info = analyze_type_info(arg_param.annotation)
219+
decoder = make_engine_value_decoder(
220+
[arg_name], actual_arg.value_type["type"], type_info
221+
)
222+
is_required = not type_info.nullable
223+
if is_required and actual_arg.value_type.get("nullable", False):
224+
potentially_missing_required_arg = True
225+
return _ArgInfo(
226+
decoder=decoder,
227+
is_required=is_required,
228+
)
206229

207230
# Match arguments with parameters.
208231
next_param_idx = 0
209-
for arg in args:
232+
for actual_arg in args:
210233
if next_param_idx >= len(expected_args):
211234
raise ValueError(
212235
f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
@@ -219,20 +242,13 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
219242
raise ValueError(
220243
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
221244
)
222-
self._args_decoders.append(
223-
make_engine_value_decoder(
224-
[arg_name],
225-
arg.value_type["type"],
226-
analyze_type_info(arg_param.annotation),
227-
)
228-
)
229-
process_attribute(arg_name, arg)
245+
self._args_info.append(process_arg(arg_name, arg_param, actual_arg))
230246
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
231247
next_param_idx += 1
232248

233249
expected_kwargs = expected_args[next_param_idx:]
234250

235-
for kwarg_name, kwarg in kwargs.items():
251+
for kwarg_name, actual_arg in kwargs.items():
236252
expected_arg = next(
237253
(
238254
arg
@@ -254,12 +270,9 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
254270
f"Unexpected keyword argument passed in: {kwarg_name}"
255271
)
256272
arg_param = expected_arg[1]
257-
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
258-
[kwarg_name],
259-
kwarg.value_type["type"],
260-
analyze_type_info(arg_param.annotation),
273+
self._kwargs_info[kwarg_name] = process_arg(
274+
kwarg_name, arg_param, actual_arg
261275
)
262-
process_attribute(kwarg_name, kwarg)
263276

264277
missing_args = [
265278
name
@@ -280,32 +293,45 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
280293
if len(missing_args) > 0:
281294
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
282295

283-
prepare_method = getattr(executor_cls, "analyze", None)
284-
if prepare_method is not None:
285-
result = prepare_method(self, *args, **kwargs)
296+
base_analyze_method = getattr(self, "analyze", None)
297+
if base_analyze_method is not None:
298+
result = base_analyze_method(self, *args, **kwargs)
286299
else:
287300
result = expected_return
288301
if len(attributes) > 0:
289302
result = Annotated[result, *attributes]
290-
return result
303+
304+
encoded_type = encode_enriched_type(result)
305+
if potentially_missing_required_arg:
306+
encoded_type["nullable"] = True
307+
return encoded_type
291308

292309
async def prepare(self) -> None:
    """
    Prepare for execution.
    It's executed after `analyze` and before any `__call__` execution.

    Delegates to the base class's `prepare` when one exists, running it
    through `_to_async_call`; otherwise this is a no-op.
    """
    base_prepare = getattr(super(), "prepare", None)
    if base_prepare is None:
        return
    await _to_async_call(base_prepare)()
300317

301318
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
302-
decoded_args = (
303-
decoder(arg) for decoder, arg in zip(self._args_decoders, args)
304-
)
305-
decoded_kwargs = {
306-
arg_name: self._kwargs_decoders[arg_name](arg)
307-
for arg_name, arg in kwargs.items()
308-
}
319+
decoded_args = []
320+
for arg_info, arg in zip(self._args_info, args):
321+
if arg_info.is_required and arg is None:
322+
return None
323+
decoded_args.append(arg_info.decoder(arg))
324+
325+
decoded_kwargs = {}
326+
for kwarg_name, arg in kwargs.items():
327+
kwarg_info = self._kwargs_info.get(kwarg_name)
328+
if kwarg_info is None:
329+
raise ValueError(
330+
f"Unexpected keyword argument passed in: {kwarg_name}"
331+
)
332+
if kwarg_info.is_required and arg is None:
333+
return None
334+
decoded_kwargs[kwarg_name] = kwarg_info.decoder(arg)
309335

310336
if op_args.gpu:
311337
# For GPU executions, data-level parallelism is applied, so we don't want to

python/cocoindex/tests/test_transform_flow.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,44 @@ async def test_for_each_transform_flow_async() -> None:
101101
}
102102

103103
assert result == expected, f"Expected {expected}, got {result}"
104+
105+
106+
def test_none_arg_yield_none_result() -> None:
    """Check that a None value for a required argument yields a None result."""

    @cocoindex.op.function()
    def custom_fn(
        required_arg: int,
        optional_arg: int | None,
        required_kwarg: int,
        optional_kwarg: int | None,
    ) -> int:
        # Optional inputs count as 0 when they are None.
        return sum(
            (
                required_arg,
                optional_arg or 0,
                required_kwarg,
                optional_kwarg or 0,
            )
        )

    @cocoindex.transform_flow()
    def transform_flow(
        required_arg: cocoindex.DataSlice[int | None],
        optional_arg: cocoindex.DataSlice[int | None],
        required_kwarg: cocoindex.DataSlice[int | None],
        optional_kwarg: cocoindex.DataSlice[int | None],
    ) -> cocoindex.DataSlice[int | None]:
        return required_arg.transform(
            custom_fn,
            optional_arg,
            required_kwarg=required_kwarg,
            optional_kwarg=optional_kwarg,
        )

    # (inputs, expected result), evaluated in order.
    cases = [
        ((1, 2, 4, 8), 15),  # every argument present
        ((1, None, 4, None), 5),  # only optional arguments are None
        ((None, 2, 4, 8), None),  # required positional argument is None
        ((1, 2, None, None), None),  # required keyword argument is None
    ]
    for inputs, expected in cases:
        result = transform_flow.eval(*inputs)
        if expected is None:
            assert result is None, f"Expected None, got {result}"
        else:
            assert result == expected, f"Expected {expected}, got {result}"

0 commit comments

Comments
 (0)