Skip to content

Commit fd09b45

Browse files
committed
Fix match_representations to handle ragged schemas w/ fixed values
1 parent 7a60ef9 commit fd09b45

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

merlin/systems/triton/conversions.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -83,17 +83,23 @@ def match_representations(schema: Schema, dict_array: Dict[str, Any]) -> Dict[st
8383
Dict[str, Any]
8484
A dictionary of NumPy or CuPy ndarrays with representations adjusted
8585
"""
86-
schema_names = tensor_names(schema)
87-
8886
aligned = {}
89-
for tensor_name in dict_array.keys():
90-
if tensor_name in schema_names:
91-
aligned[tensor_name] = dict_array[tensor_name]
87+
for col_name, col_schema in schema.column_schemas.items():
88+
if col_schema.is_ragged:
89+
vals_name = f"{col_name}__values"
90+
offs_name = f"{col_name}__offsets"
91+
92+
try:
93+
# Look for values and offsets that already exist
94+
aligned[vals_name] = dict_array[vals_name]
95+
aligned[offs_name] = dict_array[offs_name]
96+
except KeyError:
97+
# If you don't find them, create the offsets
98+
values, offsets = _to_values_offsets(dict_array[col_name])
99+
aligned[vals_name] = values
100+
aligned[offs_name] = offsets
92101
else:
93-
# Ragged columns with fixed shape values
94-
values, offsets = _to_values_offsets(dict_array[tensor_name])
95-
aligned[f"{tensor_name}__values"] = values
96-
aligned[f"{tensor_name}__offsets"] = offsets
102+
aligned[col_name] = dict_array[col_name]
97103

98104
return aligned
99105

0 commit comments

Comments
 (0)