Skip to content

Commit 12a2094

Browse files
philipportner authored and mboehm7 committed
[SYSTEMDS-3708] Fix raGroupby permutation-matrix method
Adds a test case with inputs that have multiple groups with varying row counts. This pattern comes from a `lineorder.csv` example dataset that currently causes a runtime exception for the `permutation-matrix` approach but works for the `nested-loop` approach.

Why this happened:
- The `permutation-matrix` approach allocated space assuming every group has `maxRowsInGroup` rows, which is not always the case.
- Groups may have variable sizes, resulting in `Y_temp_reduce` having fewer rows than the reshape expects.

Changes:
- Correctly pads the matrix when groups do not all have `maxRowsInGroup` rows.
- Adds test cases that cover this pattern.

Closes #2288.
1 parent 64455b9 commit 12a2094

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

scripts/builtin/raGroupby.dml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,22 @@ m_raGroupby = function (Matrix[Double] X, Integer col, String method)
132132
# Set value of final output
133133
Y = matrix(0, rows=numGroups, cols=totalCells)
134134
Y[,1] = key_unique
135-
Y[,2:ncol(Y)] = matrix(Y_temp_reduce, rows=numGroups, cols=totalCells-1)
135+
136+
# The permutation matrix creates a structure where each group's data
137+
# may not fill exactly maxRowsInGroup rows.
138+
# If needed, we need to pad to the expected size first.
139+
expectedRows = numGroups * maxRowsInGroup
140+
actualRows = nrow(Y_temp_reduce)
141+
142+
if(actualRows < expectedRows) {
143+
# Pad Y_temp_reduce with zeros to match expected structure
144+
Y_tmp_padded = matrix(0, rows=expectedRows, cols=ncol(Y_temp_reduce))
145+
Y_tmp_padded[1:actualRows,] = Y_temp_reduce
146+
} else {
147+
Y_tmp_padded = Y_temp_reduce
148+
}
149+
150+
Y[,2:ncol(Y)] = matrix(Y_tmp_padded, rows=numGroups, cols=totalCells-1)
136151
}
137152
}
138153

src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaGroupbyTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ public void testRaGroupbyTestwithOneGroup2() {
8080
testRaGroupbyTestwithOneGroup("permutation-matrix");
8181
}
8282

83+
@Test
84+
public void testRaGroupbyTestwithMultipleGroupRows1() {
85+
testRaGroupbyTestwithMultipleGroupRows("nested-loop");
86+
}
87+
88+
@Test
89+
public void testRaGroupbyTestwithMultipleGroupRows2() {
90+
testRaGroupbyTestwithMultipleGroupRows("permutation-matrix");
91+
}
92+
8393
public void testRaGroupbyTest(String method) {
8494
//generate actual dataset and variables
8595
double[][] X = {
@@ -160,6 +170,41 @@ public void testRaGroupbyTestwithOneGroup(String method) {
160170
runRaGroupbyTest(X, select_col, Y, method);
161171
}
162172

173+
public void testRaGroupbyTestwithMultipleGroupRows(String method) {
174+
// Test case with multiple groups having different numbers of rows
175+
// 10 rows x 5 columns, grouping by column 2
176+
// Groups: 1->3 rows, 2->2 rows, 3->2 rows, 4->2 rows, 5->1 row
177+
double[][] X = {
178+
{1, 1, 11, 12, 13},
179+
{1, 2, 21, 22, 23},
180+
{1, 3, 31, 32, 33},
181+
{1, 4, 41, 42, 43},
182+
{2, 1, 14, 15, 16},
183+
{2, 2, 24, 25, 26},
184+
{2, 3, 34, 35, 36},
185+
{2, 4, 44, 45, 46},
186+
{2, 5, 54, 55, 56},
187+
{3, 1, 17, 18, 19}};
188+
int select_col = 2;
189+
190+
// Expected output matrix (grouping by column 2, removing column 2)
191+
// Note: Groups are ordered as they appear in the unique() function output
192+
// Group 1: 3 rows -> [1,11,12,13], [2,14,15,16], [3,17,18,19]
193+
// Group 2: 2 rows -> [1,21,22,23], [2,24,25,26]
194+
// Group 4: 2 rows -> [1,41,42,43], [2,44,45,46]
195+
// Group 5: 1 row -> [2,54,55,56]
196+
// Group 3: 2 rows -> [1,31,32,33], [2,34,35,36]
197+
double[][] Y = {
198+
{1, 1, 11, 12, 13, 2, 14, 15, 16, 3, 17, 18, 19},
199+
{2, 1, 21, 22, 23, 2, 24, 25, 26, 0, 0, 0, 0},
200+
{4, 1, 41, 42, 43, 2, 44, 45, 46, 0, 0, 0, 0},
201+
{5, 2, 54, 55, 56, 0, 0, 0, 0, 0, 0, 0, 0},
202+
{3, 1, 31, 32, 33, 2, 34, 35, 36, 0, 0, 0, 0}
203+
};
204+
205+
runRaGroupbyTest(X, select_col, Y, method);
206+
}
207+
163208
private void runRaGroupbyTest(double [][] X, int col, double [][] Y, String method)
164209
{
165210
ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);

0 commit comments

Comments
 (0)