Skip to content

Commit 7a60ef9

Browse files
committed
Fix _to_values_offsets handle converting 1d values
This function is now used in more cases than it used to be and needs to be generalized a bit.
1 parent e94d2a9 commit 7a60ef9

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

merlin/systems/triton/conversions.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,19 @@ def _to_values_offsets(array):
111111
values, offsets
112112
Tuple of values and offsets
113113
"""
114-
num_rows = array.shape[0]
115-
row_lengths = [array.shape[1]] * num_rows
116-
offsets = [0] + list(itertools.accumulate(row_lengths))
117114
array_lib = cp if cp and isinstance(array, cp.ndarray) else np
115+
shape = array.shape
116+
117+
values_shape = [-1]
118+
if len(shape) > 2:
119+
values_shape.extend(shape[2:])
120+
values = array.reshape(*values_shape)
121+
122+
num_rows = shape[0]
123+
row_lengths = [shape[1]] * num_rows if len(shape) > 1 else [1]
124+
offsets = [0] + list(itertools.accumulate(row_lengths))
118125
offsets = array_lib.array(offsets, dtype="int32")
119-
values = array.reshape(-1, *array.shape[2:])
126+
120127
return values, offsets
121128

122129

0 commit comments

Comments
 (0)