Skip to content

Commit dbceaf5

Browse files
committed
update groupby.apply to pass in array of series as DataFrame into apply callback function. inspired by #221
1 parent 165a901 commit dbceaf5

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

danfojs-node/dist/core/groupby.js

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ var _utils = require("./utils");
1111

1212
var _series = require("./series");
1313

14+
var _concat = require("./concat");
15+
1416
const utils = new _utils.Utils();
1517

1618
class GroupBy {
@@ -229,9 +231,13 @@ class GroupBy {
229231
function concatPathAndNode(path, node, col_dtype) {
230232
if (Array.isArray(node)) {
231233
if (Array.isArray(node[0])) {
232-
const transposed_node = node[0].map((_, colIndex) => node.map(row => row[colIndex]));
234+
if (ops != "apply") {
235+
const transposed_node = node[0].map((_, colIndex) => node.map(row => row[colIndex]));
233236

234-
for (const n_array of transposed_node) df_data.push(path.concat(n_array));
237+
for (const n_array of transposed_node) df_data.push(path.concat(n_array));
238+
} else {
239+
for (const n_array of node) df_data.push(path.concat(n_array));
240+
}
235241
} else df_data.push(path.concat(node));
236242
} else {
237243
for (const [k, child] of Object.entries(node)) {
@@ -274,10 +280,25 @@ class GroupBy {
274280
function recursiveCount(sub_df_data, sub_count_group) {
275281
for (const [key, value] of Object.entries(sub_df_data)) {
276282
if (Array.isArray(value)) {
277-
sub_count_group[key] = value.map(callable_value => {
278-
const callable_rslt = callable(callable_value);
279-
if (callable_rslt instanceof _frame.DataFrame || callable_rslt instanceof _series.Series) return callable_rslt.values;else return callable_rslt;
280-
});
283+
let callable_value;
284+
285+
if (value.length > 1) {
286+
callable_value = (0, _concat.concat)({
287+
df_list: value,
288+
axis: 1
289+
});
290+
} else {
291+
callable_value = value[0];
292+
}
293+
294+
const callable_rslt = callable(callable_value);
295+
296+
if (callable_rslt instanceof _frame.DataFrame) {
297+
column = callable_rslt.columns;
298+
sub_count_group[key] = callable_rslt.values;
299+
} else if (callable_rslt instanceof _series.Series) {
300+
sub_count_group[key] = callable_rslt.values;
301+
} else sub_count_group = callable_rslt;
281302
} else {
282303
sub_count_group[key] = {};
283304
recursiveCount(value, sub_count_group[key]);

0 commit comments

Comments
 (0)