Skip to content

Commit 165a901

Browse files
committed
update groupby.apply to pass in array of series as DataFrame into apply callback function. inspired by #221
1 parent 1eeb6af commit 165a901

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

danfojs-browser/src/core/groupby.js

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { DataFrame } from "./frame";
22
import { Utils } from "./utils";
33
import { Series } from "./series";
4+
import { concat } from "./concat";
45
const utils = new Utils;
56

67
/**
@@ -271,9 +272,15 @@ export class GroupBy {
271272
function concatPathAndNode(path, node, col_dtype) {
272273
if (Array.isArray(node)) {
273274
if (Array.isArray(node[0])) {
274-
const transposed_node = node[0].map((_, colIndex) => node.map((row) => row[colIndex]));
275-
for (const n_array of transposed_node)
276-
df_data.push(path.concat(n_array));
275+
if (ops != "apply" ) {
276+
const transposed_node = node[0].map((_, colIndex) => node.map((row) => row[colIndex]));
277+
for (const n_array of transposed_node)
278+
df_data.push(path.concat(n_array));
279+
} else {
280+
for (const n_array of node)
281+
df_data.push(path.concat(n_array));
282+
}
283+
277284
} else
278285
df_data.push(path.concat(node));
279286
} else {
@@ -314,13 +321,20 @@ export class GroupBy {
314321
function recursiveCount(sub_df_data, sub_count_group) {
315322
for (const [ key, value ] of Object.entries(sub_df_data)) {
316323
if (Array.isArray(value)) {
317-
sub_count_group[key] = value.map(( callable_value ) => {
318-
const callable_rslt = callable(callable_value);
319-
if ((callable_rslt instanceof DataFrame) || (callable_rslt instanceof Series))
320-
return callable_rslt.values;
321-
else
322-
return callable_rslt;
323-
});
324+
let callable_value;
325+
if (value.length > 1) {
326+
callable_value = concat({ df_list: value, axis: 1 });
327+
} else {
328+
callable_value = value[0];
329+
}
330+
const callable_rslt = callable(callable_value);
331+
if (callable_rslt instanceof DataFrame) {
332+
column = callable_rslt.columns;
333+
sub_count_group[key] = callable_rslt.values;
334+
} else if (callable_rslt instanceof Series) {
335+
sub_count_group[key] = callable_rslt.values;
336+
} else
337+
sub_count_group = callable_rslt;
324338
} else {
325339
sub_count_group[key] = {};
326340
recursiveCount(value, sub_count_group[key]);

0 commit comments

Comments
 (0)