Skip to content

Commit 6a81e20

Browse files
committed
fix: added the handling for all the substrait type in cover
1 parent 13c0f1c commit 6a81e20

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

src/substrait/extension_registry/signature_checker_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,14 @@ def _handle_parameterized_type(
235235
covered.decimal, parameterized_type, ["scale", "precision"], parameters
236236
)
237237

238+
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimeContext):
239+
return kind == "precision_time" and check_integer_type_parameters(
240+
covered.precision_time,
241+
parameterized_type,
242+
["precision"],
243+
parameters,
244+
)
245+
238246
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext):
239247
return kind == "precision_timestamp" and check_integer_type_parameters(
240248
covered.precision_timestamp,
@@ -251,6 +259,14 @@ def _handle_parameterized_type(
251259
parameters,
252260
)
253261

262+
if isinstance(parameterized_type, SubstraitTypeParser.PrecisionIntervalDayContext):
263+
return kind == "interval_day" and check_integer_type_parameters(
264+
covered.interval_day,
265+
parameterized_type,
266+
["precision"],
267+
parameters,
268+
)
269+
254270
if isinstance(parameterized_type, SubstraitTypeParser.ListContext):
255271
return kind == "list" and covers(
256272
covered.list.type,

tests/test_extension_registry.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,102 @@ def test_registry_default_extensions_have_uri_mappings():
506506

507507
assert registry._uri_urn_bimap.get_urn(uri) == urn
508508

509+
def test_registry_default_extensions_lookup_function_multiply():
510+
"""Test that default extensions are loaded and functions can be looked up."""
511+
registry = ExtensionRegistry(load_default_extensions=True)
512+
513+
# Test looking up a function from the comparison extensions
514+
urn = "extension:io.substrait:functions_arithmetic"
515+
516+
# Look up a common comparison function (e.g., "multiply")
517+
result = registry.lookup_function(
518+
urn=urn,
519+
function_name="multiply",
520+
signature=[i8(nullable=False), i8(nullable=False)],
521+
)
522+
523+
assert result is not None, "Failed to lookup 'multiply' function from default extensions"
524+
entry, return_type = result
525+
526+
# Verify the function entry
527+
assert entry.name == "multiply"
528+
assert entry.urn == urn
529+
assert entry.function_type is not None
530+
assert entry.function_type.value == "scalar"
531+
assert isinstance(entry.anchor, int)
532+
533+
# Verify the URI-URN mapping exists
534+
uri = registry._uri_urn_bimap.get_uri(urn)
535+
assert uri is not None
536+
assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri
537+
assert "functions_arithmetic.yaml" in uri
538+
539+
# Test looking up a function across all URNs without specifying URN
540+
results = registry.list_functions_across_urns(
541+
function_name="multiply",
542+
signature=[i8(nullable=False), i8(nullable=False)],
543+
)
544+
545+
assert len(results) > 0, "Failed to find 'multiply' function across all URNs"
546+
547+
# Verify we found the same function
548+
found_entry = None
549+
for entry, return_type in results:
550+
if entry.urn == urn and entry.name == "multiply":
551+
found_entry = entry
552+
break
553+
554+
assert found_entry is not None, "multiply function not found in cross-URN search"
555+
assert found_entry.function_type.value == "scalar"
556+
557+
def test_registry_default_extensions_lookup_function():
558+
"""Test that default extensions are loaded and functions can be looked up."""
559+
registry = ExtensionRegistry(load_default_extensions=True)
560+
561+
# Test looking up a function from the comparison extensions
562+
urn = "extension:io.substrait:functions_comparison"
563+
564+
# Look up a common comparison function (e.g., "equal")
565+
result = registry.lookup_function(
566+
urn=urn,
567+
function_name="equal",
568+
signature=[i8(nullable=False), i8(nullable=False)],
569+
)
570+
571+
assert result is not None, "Failed to lookup 'equal' function from default extensions"
572+
entry, return_type = result
573+
574+
# Verify the function entry
575+
assert entry.name == "equal"
576+
assert entry.urn == urn
577+
assert entry.function_type is not None
578+
assert entry.function_type.value == "scalar"
579+
assert isinstance(entry.anchor, int)
580+
581+
# Verify the URI-URN mapping exists
582+
uri = registry._uri_urn_bimap.get_uri(urn)
583+
assert uri is not None
584+
assert "https://github.com/substrait-io/substrait/blob/main/extensions" in uri
585+
assert "functions_comparison.yaml" in uri
586+
587+
# Test looking up a function across all URNs without specifying URN
588+
results = registry.list_functions_across_urns(
589+
function_name="equal",
590+
signature=[i8(nullable=False), i8(nullable=False)],
591+
)
592+
593+
assert len(results) > 0, "Failed to find 'equal' function across all URNs"
594+
595+
# Verify we found the same function
596+
found_entry = None
597+
for entry, return_type in results:
598+
if entry.urn == urn and entry.name == "equal":
599+
found_entry = entry
600+
break
601+
602+
assert found_entry is not None, "Equal function not found in cross-URN search"
603+
assert found_entry.function_type.value == "scalar"
604+
509605

510606
def test_valid_urn_format():
511607
"""Test that valid URN formats are accepted."""

0 commit comments

Comments
 (0)