Skip to content

Commit 13548dc

Browse files
committed
update groupby.apply to pass in array of series as DataFrame into apply callback function. inspired by #221
1 parent dbceaf5 commit 13548dc

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

danfojs-node/src/core/groupby.js

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { DataFrame } from "./frame";
22
import { Utils } from "./utils";
33
import { Series } from "./series";
4+
import { concat } from "./concat";
45
const utils = new Utils;
56

67
/**
@@ -267,9 +268,15 @@ export class GroupBy {
267268
function concatPathAndNode(path, node, col_dtype) {
268269
if (Array.isArray(node)) {
269270
if (Array.isArray(node[0])) {
270-
const transposed_node = node[0].map((_, colIndex) => node.map((row) => row[colIndex]));
271-
for (const n_array of transposed_node)
272-
df_data.push(path.concat(n_array));
271+
if (ops != "apply" ) {
272+
const transposed_node = node[0].map((_, colIndex) => node.map((row) => row[colIndex]));
273+
for (const n_array of transposed_node)
274+
df_data.push(path.concat(n_array));
275+
} else {
276+
for (const n_array of node)
277+
df_data.push(path.concat(n_array));
278+
}
279+
273280
} else
274281
df_data.push(path.concat(node));
275282
} else {
@@ -310,13 +317,20 @@ export class GroupBy {
310317
function recursiveCount(sub_df_data, sub_count_group) {
311318
for (const [ key, value ] of Object.entries(sub_df_data)) {
312319
if (Array.isArray(value)) {
313-
sub_count_group[key] = value.map(( callable_value ) => {
314-
const callable_rslt = callable(callable_value);
315-
if ((callable_rslt instanceof DataFrame) || (callable_rslt instanceof Series))
316-
return callable_rslt.values;
317-
else
318-
return callable_rslt;
319-
});
320+
let callable_value;
321+
if (value.length > 1) {
322+
callable_value = concat({ df_list: value, axis: 1 });
323+
} else {
324+
callable_value = value[0];
325+
}
326+
const callable_rslt = callable(callable_value);
327+
if (callable_rslt instanceof DataFrame) {
328+
column = callable_rslt.columns;
329+
sub_count_group[key] = callable_rslt.values;
330+
} else if (callable_rslt instanceof Series) {
331+
sub_count_group[key] = callable_rslt.values;
332+
} else
333+
sub_count_group = callable_rslt;
320334
} else {
321335
sub_count_group[key] = {};
322336
recursiveCount(value, sub_count_group[key]);

0 commit comments

Comments
 (0)