Skip to content

Commit 8816b46

Browse files
Add additional generic types to DataFrame methods (#302)
Adding generic types to a few more methods beyond what was added in #293 by @scarf005 Focusing mostly on adding identity types to methods which I believe don’t change the original type of the dataframe. I added “identity” type signatures to the following methods: > extend, fillNull, filter, interpolate, limit, max, mean, median, min, quantile, rechunk, shiftAndFill, shrinkToFit, slice, sort, std, sum, tail, unique, var, vstack, where, upsample These previously returned `DataFrame<any>`, even when called on a well-typed DataFrame, but now return `DataFrame<T>` (the original type) --- I also added better types for a few slightly more complex ones: - map - improved return type based on the function passed, but unimproved parameter type - nullCount - toRecords - toSeries - for now, returning a broad union type, rather than identifying the specific column by index - withColumn --- Along the way, I added minor fixes for the types of: 1. `pl.intRange` [[1]](890bf21) which had overloads in the wrong order leading to incorrect return types, and 2. the `pl.Series(name, values, dtype)` constructor [[2]](a2635bd), whose strongly-typed overload was failing to apply in simple cases like `pl.Series("index", [0, 1, 2, 3, 4], pl.Int64)` when the input array used `number`s instead of `BigInt`s
1 parent 62a70dc commit 8816b46

File tree

5 files changed

+108
-65
lines changed

5 files changed

+108
-65
lines changed

__tests__/expr.test.ts

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,20 +525,22 @@ describe("expr", () => {
525525
a: [1, 2, 3, 3, 3],
526526
b: ["a", "a", "b", "a", "a"],
527527
});
528-
let actual = df.select(pl.len());
529-
let expected = pl.DataFrame({ len: [5] });
528+
const actual = df.select(pl.len());
529+
const expected = pl.DataFrame({ len: [5] });
530530
expect(actual).toFrameEqual(expected);
531531

532-
actual = df.withColumn(pl.len());
533-
expected = df.withColumn(pl.lit(5).alias("len"));
534-
expect(actual).toFrameEqual(expected);
532+
const actual2 = df.withColumn(pl.len());
533+
const expected2 = df.withColumn(pl.lit(5).alias("len"));
534+
expect(actual2).toFrameEqual(expected2);
535535

536-
actual = df.withColumn(pl.intRange(pl.len()).alias("index"));
537-
expected = df.withColumn(pl.Series("index", [0, 1, 2, 3, 4], pl.Int64));
538-
expect(actual).toFrameEqual(expected);
536+
const actual3 = df.withColumn(pl.intRange(pl.len()).alias("index"));
537+
const expected3 = df.withColumn(
538+
pl.Series("index", [0, 1, 2, 3, 4], pl.Int64),
539+
);
540+
expect(actual3).toFrameEqual(expected3);
539541

540-
actual = df.groupBy("b").agg(pl.len());
541-
expect(actual.shape).toEqual({ height: 2, width: 2 });
542+
const actual4 = df.groupBy("b").agg(pl.len());
543+
expect(actual4.shape).toEqual({ height: 2, width: 2 });
542544
});
543545
test("list", () => {
544546
const df = pl.DataFrame({

polars/dataframe.ts

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
466466
467467
* @param other DataFrame to vertically add.
468468
*/
469-
extend(other: DataFrame): DataFrame;
469+
extend(other: DataFrame<T>): DataFrame<T>;
470470
/**
471471
* Fill null/missing values by a filling strategy
472472
*
@@ -480,7 +480,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
480480
* - "one"
481481
* @returns DataFrame with None replaced with the filling strategy.
482482
*/
483-
fillNull(strategy: FillNullStrategy): DataFrame;
483+
fillNull(strategy: FillNullStrategy): DataFrame<T>;
484484
/**
485485
* Filter the rows in the DataFrame based on a predicate expression.
486486
* ___
@@ -519,7 +519,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
519519
* └─────┴─────┴─────┘
520520
* ```
521521
*/
522-
filter(predicate: any): DataFrame;
522+
filter(predicate: any): DataFrame<T>;
523523
/**
524524
* Find the index of a column by name.
525525
* ___
@@ -764,7 +764,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
764764
/**
765765
* Interpolate intermediate values. The interpolation method is linear.
766766
*/
767-
interpolate(): DataFrame;
767+
interpolate(): DataFrame<T>;
768768
/**
769769
* Get a mask of all duplicated rows in this DataFrame.
770770
*/
@@ -937,8 +937,11 @@ export interface DataFrame<T extends Record<string, Series> = any>
937937
* Get first N rows as DataFrame.
938938
* @see {@link head}
939939
*/
940-
limit(length?: number): DataFrame;
941-
map(func: (...args: any[]) => any): any[];
940+
limit(length?: number): DataFrame<T>;
941+
map<ReturnT>(
942+
// TODO: strong types for the mapping function
943+
func: (row: any[], i: number, arr: any[][]) => ReturnT,
944+
): ReturnT[];
942945

943946
/**
944947
* Aggregate the columns of this DataFrame to their maximum value.
@@ -962,8 +965,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
962965
* ╰─────┴─────┴──────╯
963966
* ```
964967
*/
965-
max(): DataFrame;
966-
max(axis: 0): DataFrame;
968+
max(): DataFrame<T>;
969+
max(axis: 0): DataFrame<T>;
967970
max(axis: 1): Series;
968971
/**
969972
* Aggregate the columns of this DataFrame to their mean value.
@@ -972,8 +975,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
972975
* @param axis - either 0 or 1
973976
* @param nullStrategy - this argument is only used if axis == 1
974977
*/
975-
mean(): DataFrame;
976-
mean(axis: 0): DataFrame;
978+
mean(): DataFrame<T>;
979+
mean(axis: 0): DataFrame<T>;
977980
mean(axis: 1): Series;
978981
mean(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
979982
/**
@@ -997,7 +1000,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
9971000
* ╰─────┴─────┴──────╯
9981001
* ```
9991002
*/
1000-
median(): DataFrame;
1003+
median(): DataFrame<T>;
10011004
/**
10021005
* Unpivot a DataFrame from wide to long format.
10031006
* @deprecated *since 0.13.0* use {@link unpivot}
@@ -1059,8 +1062,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
10591062
* ╰─────┴─────┴──────╯
10601063
* ```
10611064
*/
1062-
min(): DataFrame;
1063-
min(axis: 0): DataFrame;
1065+
min(): DataFrame<T>;
1066+
min(axis: 0): DataFrame<T>;
10641067
min(axis: 1): Series;
10651068
/**
10661069
* Get number of chunks used by the ChunkedArrays of this DataFrame.
@@ -1087,12 +1090,14 @@ export interface DataFrame<T extends Record<string, Series> = any>
10871090
* └─────┴─────┴─────┘
10881091
* ```
10891092
*/
1090-
nullCount(): DataFrame;
1093+
nullCount(): DataFrame<{
1094+
[K in keyof T]: Series<JsToDtype<number>, K & string>;
1095+
}>;
10911096
partitionBy(
10921097
cols: string | string[],
10931098
stable?: boolean,
10941099
includeKey?: boolean,
1095-
): DataFrame[];
1100+
): DataFrame<T>[];
10961101
partitionBy<T>(
10971102
cols: string | string[],
10981103
stable: boolean,
@@ -1210,13 +1215,13 @@ export interface DataFrame<T extends Record<string, Series> = any>
12101215
* ╰─────┴─────┴──────╯
12111216
* ```
12121217
*/
1213-
quantile(quantile: number): DataFrame;
1218+
quantile(quantile: number): DataFrame<T>;
12141219
/**
12151220
* __Rechunk the data in this DataFrame to a contiguous allocation.__
12161221
*
12171222
* This will make sure all subsequent operations have optimal and predictable performance.
12181223
*/
1219-
rechunk(): DataFrame;
1224+
rechunk(): DataFrame<T>;
12201225
/**
12211226
* __Rename column names.__
12221227
* ___
@@ -1443,12 +1448,15 @@ export interface DataFrame<T extends Record<string, Series> = any>
14431448
* └─────┴─────┴─────┘
14441449
* ```
14451450
*/
1446-
shiftAndFill(n: number, fillValue: number): DataFrame;
1447-
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame;
1451+
shiftAndFill(n: number, fillValue: number): DataFrame<T>;
1452+
shiftAndFill({
1453+
n,
1454+
fillValue,
1455+
}: { n: number; fillValue: number }): DataFrame<T>;
14481456
/**
14491457
* Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data.
14501458
*/
1451-
shrinkToFit(): DataFrame;
1459+
shrinkToFit(): DataFrame<T>;
14521460
shrinkToFit(inPlace: true): void;
14531461
shrinkToFit({ inPlace }: { inPlace: true }): void;
14541462
/**
@@ -1477,8 +1485,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
14771485
* └─────┴─────┴─────┘
14781486
* ```
14791487
*/
1480-
slice({ offset, length }: { offset: number; length: number }): DataFrame;
1481-
slice(offset: number, length: number): DataFrame;
1488+
slice({ offset, length }: { offset: number; length: number }): DataFrame<T>;
1489+
slice(offset: number, length: number): DataFrame<T>;
14821490
/**
14831491
* Sort the DataFrame by column.
14841492
* ___
@@ -1493,7 +1501,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
14931501
descending?: boolean,
14941502
nullsLast?: boolean,
14951503
maintainOrder?: boolean,
1496-
): DataFrame;
1504+
): DataFrame<T>;
14971505
sort({
14981506
by,
14991507
reverse, // deprecated
@@ -1504,7 +1512,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
15041512
reverse?: boolean; // deprecated
15051513
nullsLast?: boolean;
15061514
maintainOrder?: boolean;
1507-
}): DataFrame;
1515+
}): DataFrame<T>;
15081516
sort({
15091517
by,
15101518
descending,
@@ -1514,7 +1522,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
15141522
descending?: boolean;
15151523
nullsLast?: boolean;
15161524
maintainOrder?: boolean;
1517-
}): DataFrame;
1525+
}): DataFrame<T>;
15181526
/**
15191527
* Aggregate the columns of this DataFrame to their standard deviation value.
15201528
* ___
@@ -1536,16 +1544,16 @@ export interface DataFrame<T extends Record<string, Series> = any>
15361544
* ╰─────┴─────┴──────╯
15371545
* ```
15381546
*/
1539-
std(): DataFrame;
1547+
std(): DataFrame<T>;
15401548
/**
15411549
* Aggregate the columns of this DataFrame to their mean value.
15421550
* ___
15431551
*
15441552
* @param axis - either 0 or 1
15451553
* @param nullStrategy - this argument is only used if axis == 1
15461554
*/
1547-
sum(): DataFrame;
1548-
sum(axis: 0): DataFrame;
1555+
sum(): DataFrame<T>;
1556+
sum(axis: 0): DataFrame<T>;
15491557
sum(axis: 1): Series;
15501558
sum(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
15511559
/**
@@ -1595,7 +1603,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
15951603
* ╰─────────┴─────╯
15961604
* ```
15971605
*/
1598-
tail(length?: number): DataFrame;
1606+
tail(length?: number): DataFrame<T>;
15991607
/**
16001608
* @deprecated *since 0.4.0* use {@link writeCSV}
16011609
* @category Deprecated
@@ -1614,7 +1622,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
16141622
* ```
16151623
* @category IO
16161624
*/
1617-
toRecords(): Record<string, any>[];
1625+
toRecords(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]> | null }[];
16181626

16191627
/**
16201628
* compat with `JSON.stringify`
@@ -1644,7 +1652,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
16441652
* ```
16451653
* @category IO
16461654
*/
1647-
toObject(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]>[] };
1655+
toObject(): { [K in keyof T]: DTypeToJs<T[K]["dtype"] | null>[] };
16481656

16491657
/**
16501658
* @deprecated *since 0.4.0* use {@link writeIPC}
@@ -1656,7 +1664,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
16561664
* @category IO Deprecated
16571665
*/
16581666
toParquet(destination?, options?);
1659-
toSeries(index?: number): Series;
1667+
toSeries(index?: number): T[keyof T];
16601668
toString(): string;
16611669
/**
16621670
* Convert a ``DataFrame`` to a ``Series`` of type ``Struct``
@@ -1768,12 +1776,12 @@ export interface DataFrame<T extends Record<string, Series> = any>
17681776
maintainOrder?: boolean,
17691777
subset?: ColumnSelection,
17701778
keep?: "first" | "last",
1771-
): DataFrame;
1779+
): DataFrame<T>;
17721780
unique(opts: {
17731781
maintainOrder?: boolean;
17741782
subset?: ColumnSelection;
17751783
keep?: "first" | "last";
1776-
}): DataFrame;
1784+
}): DataFrame<T>;
17771785
/**
17781786
Decompose a struct into its fields. The fields will be inserted in to the `DataFrame` on the
17791787
location of the `struct` type.
@@ -1833,7 +1841,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
18331841
* ╰─────┴─────┴──────╯
18341842
* ```
18351843
*/
1836-
var(): DataFrame;
1844+
var(): DataFrame<T>;
18371845
/**
18381846
* Grow this DataFrame vertically by stacking a DataFrame to it.
18391847
* @param df - DataFrame to stack.
@@ -1866,12 +1874,16 @@ export interface DataFrame<T extends Record<string, Series> = any>
18661874
* ╰─────┴─────┴─────╯
18671875
* ```
18681876
*/
1869-
vstack(df: DataFrame): DataFrame;
1877+
vstack(df: DataFrame<T>): DataFrame<T>;
18701878
/**
18711879
* Return a new DataFrame with the column added or replaced.
18721880
* @param column - Series, where the name of the Series refers to the column in the DataFrame.
18731881
*/
1874-
withColumn(column: Series | Expr): DataFrame;
1882+
withColumn<SeriesTypeT extends DataType, SeriesNameT extends string>(
1883+
column: Series<SeriesTypeT, SeriesNameT>,
1884+
): DataFrame<
1885+
Simplify<T & { [K in SeriesNameT]: Series<SeriesTypeT, SeriesNameT> }>
1886+
>;
18751887
withColumn(column: Series | Expr): DataFrame;
18761888
withColumns(...columns: (Expr | Series)[]): DataFrame;
18771889
/**
@@ -1896,7 +1908,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
18961908
*/
18971909
withRowCount(name?: string): DataFrame;
18981910
/** @see {@link filter} */
1899-
where(predicate: any): DataFrame;
1911+
where(predicate: any): DataFrame<T>;
19001912
/**
19011913
Upsample a DataFrame at a regular frequency.
19021914
@@ -1972,13 +1984,13 @@ shape: (7, 3)
19721984
every: string,
19731985
by?: string | string[],
19741986
maintainOrder?: boolean,
1975-
): DataFrame;
1987+
): DataFrame<T>;
19761988
upsample(opts: {
19771989
timeColumn: string;
19781990
every: string;
19791991
by?: string | string[];
19801992
maintainOrder?: boolean;
1981-
}): DataFrame;
1993+
}): DataFrame<T>;
19821994
}
19831995

19841996
function prepareOtherArg(anyValue: any): Series {

polars/datatypes/datatype.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,20 @@ export type DTypeToJs<T> = T extends DataType.Decimal
491491
: T extends DataType.Utf8
492492
? string
493493
: never;
494+
// some objects can be constructed with a looser JS type than they’d return when converted back to JS
495+
export type DTypeToJsLoose<T> = T extends DataType.Decimal
496+
? number | bigint
497+
: T extends DataType.Float64
498+
? number | bigint
499+
: T extends DataType.Int64
500+
? number | bigint
501+
: T extends DataType.Int32
502+
? number | bigint
503+
: T extends DataType.Bool
504+
? boolean
505+
: T extends DataType.Utf8
506+
? string
507+
: never;
494508
export type DtypeToJsName<T> = T extends DataType.Decimal
495509
? "Decimal"
496510
: T extends DataType.Float64

0 commit comments

Comments
 (0)