Skip to content

Commit 3c74643

Browse files
Merge pull request #67 from brettshollenberger/rc103
rc103
2 parents 8345f8f + 2ca99d6 commit 3c74643

File tree

15 files changed

+81
-41
lines changed

15 files changed

+81
-41
lines changed

Gemfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
PATH
22
remote: .
33
specs:
4-
easy_ml (0.2.0.pre.rc102)
4+
easy_ml (0.2.0.pre.rc103)
55
activerecord
66
activerecord-import (~> 1.8.1)
77
activesupport

app/frontend/components/dataset/PreprocessingConfig.tsx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,6 @@ export function PreprocessingConfig({
10281028
label: strategy.label
10291029
})) || [])
10301030
]}
1031-
options={constants.preprocessing_strategies[selectedType]}
10321031
/>
10331032

10341033
{renderStrategySpecificInfo('training')}

app/models/easy_ml/column.rb

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -522,27 +522,32 @@ def self.from_config(config, dataset, action: :create)
522522
EasyML::Import::Column.from_config(config, dataset, action: action)
523523
end
524524

525-
def cast_statement(df, df_col, expected_dtype)
526-
expected_dtype = expected_dtype.is_a?(Polars::DataType) ? expected_dtype.class : expected_dtype
527-
actual_type = df[df_col].dtype
525+
def cast_statement(series = nil)
526+
expected_dtype = polars_datatype
527+
actual_type = series&.dtype || expected_dtype
528+
529+
return Polars.col(name).cast(expected_dtype).alias(name) if expected_dtype == actual_type
528530

529531
cast_statement = case expected_dtype.to_s
530-
when "Polars::Boolean"
532+
when /Polars::List/
533+
# we should start tracking polars args so we can know what type of list it is
534+
Polars.col(name)
535+
when /Polars::Boolean/
531536
case actual_type.to_s
532-
when "Polars::Boolean"
533-
Polars.col(df_col).cast(expected_dtype)
534-
when "Polars::Utf8", "Polars::Categorical", "Polars::String"
535-
Polars.col(df_col).eq("true").cast(expected_dtype)
536-
when "Polars::Null"
537-
Polars.col(df_col)
537+
when /Polars::Boolean/, /Polars::Int/
538+
Polars.col(name).cast(expected_dtype)
539+
when /Polars::Utf/, /Polars::Categorical/, /Polars::String/
540+
Polars.col(name).eq("true").cast(expected_dtype)
541+
when /Polars::Null/
542+
Polars.col(name)
538543
else
539-
raise "Unexpected dtype: #{actual_type} for column: #{df_col}"
544+
raise "Unexpected dtype: #{actual_type} for column: #{name}"
540545
end
541546
else
542-
Polars.col(df_col).cast(expected_dtype)
547+
Polars.col(name).cast(expected_dtype, strict: false)
543548
end
544549

545-
cast_statement.alias(df_col)
550+
cast_statement.alias(name)
546551
end
547552

548553
def cast(value)

app/models/easy_ml/column_list.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ def apply_cast(df)
101101
end
102102
cast_statements = (df.columns & schema.keys.map(&:to_s)).map do |df_col|
103103
db_col = column_index[df_col]
104-
expected_dtype = schema[df_col.to_sym]
105-
db_col.cast_statement(df, df_col, expected_dtype)
104+
db_col.cast_statement(df[df_col])
106105
end
107106
df = df.with_columns(cast_statements)
107+
df
108108
end
109109

110110
def cast(processed_or_raw)

app/models/easy_ml/dataset/learner/lazy.rb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ def run_queries(split, type)
2828
)
2929
.select(queries).collect
3030
rescue => e
31-
problematic_query = queries.detect {
31+
problematic_queries = queries.select { |query|
3232
begin
33-
dataset.send(type).send(split, all_columns: true, lazy: true).select(queries).collect
33+
dataset.send(type).send(split, all_columns: true, lazy: true).select([query]).collect
3434
false
3535
rescue => e
3636
true
3737
end
3838
}
39-
raise "Query failed for column #{problematic_query}, likely wrong datatype"
39+
raise "Query failed for queries... likely due to wrong column datatype: #{problematic_queries.join("\n")}"
4040
end
4141
end
4242

@@ -64,4 +64,4 @@ def build_queries(split, type)
6464
end
6565
end
6666
end
67-
end
67+
end

app/models/easy_ml/dataset/learner/lazy/datetime.rb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def full_dataset_query
1010
end
1111

1212
def unique_count
13-
Polars.col(column.name).n_unique.alias("#{column.name}__unique_count")
13+
Polars.col(column.name)
14+
.cast(column.polars_datatype)
15+
.n_unique.alias("#{column.name}__unique_count")
1416
end
1517
end
1618
end

app/models/easy_ml/dataset/learner/lazy/numeric.rb

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,30 @@ class Lazy
55
class Numeric < Query
66
def train_query
77
super.concat([
8-
Polars.col(column.name).mean.alias("#{column.name}__mean"),
9-
Polars.col(column.name).median.alias("#{column.name}__median"),
10-
Polars.col(column.name).min.alias("#{column.name}__min"),
11-
Polars.col(column.name).max.alias("#{column.name}__max"),
12-
Polars.col(column.name).std.alias("#{column.name}__std"),
8+
Polars.col(column.name)
9+
.cast(column.polars_datatype)
10+
.mean
11+
.alias("#{column.name}__mean"),
12+
13+
Polars.col(column.name)
14+
.cast(column.polars_datatype)
15+
.median
16+
.alias("#{column.name}__median"),
17+
18+
Polars.col(column.name)
19+
.cast(column.polars_datatype)
20+
.min
21+
.alias("#{column.name}__min"),
22+
23+
Polars.col(column.name)
24+
.cast(column.polars_datatype)
25+
.max
26+
.alias("#{column.name}__max"),
27+
28+
Polars.col(column.name)
29+
.cast(column.polars_datatype)
30+
.std
31+
.alias("#{column.name}__std"),
1332
])
1433
end
1534
end

app/models/easy_ml/dataset/learner/lazy/query.rb

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,37 @@ def train_query
4444
end
4545

4646
def null_count
47-
Polars.col(column.name).null_count.alias("#{column.name}__null_count")
47+
Polars.col(column.name)
48+
.cast(column.polars_datatype)
49+
.null_count
50+
.alias("#{column.name}__null_count")
4851
end
4952

5053
def num_rows
51-
Polars.col(column.name).len.alias("#{column.name}__num_rows")
54+
Polars.col(column.name)
55+
.cast(column.polars_datatype)
56+
.len
57+
.alias("#{column.name}__num_rows")
5258
end
5359

5460
def most_frequent_value
55-
Polars.col(column.name).filter(Polars.col(column.name).is_not_null).mode.first.alias("#{column.name}__most_frequent_value")
61+
Polars.col(column.name)
62+
.cast(column.polars_datatype)
63+
.filter(Polars.col(column.name).is_not_null)
64+
.mode
65+
.first
66+
.alias("#{column.name}__most_frequent_value")
5667
end
5768

5869
def last_value
5970
return unless dataset.date_column.present?
6071

6172
Polars.col(column.name)
62-
.sort_by(dataset.date_column.name, reverse: true, nulls_last: true)
63-
.filter(Polars.col(column.name).is_not_null)
64-
.first
65-
.alias("#{column.name}__last_value")
73+
.cast(column.polars_datatype)
74+
.sort_by(dataset.date_column.name, reverse: true, nulls_last: true)
75+
.filter(Polars.col(column.name).is_not_null)
76+
.first
77+
.alias("#{column.name}__last_value")
6678
end
6779
end
6880
end

app/models/easy_ml/dataset/learner/lazy/string.rb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ def full_dataset_query
1010
end
1111

1212
def unique_count
13-
Polars.col(column.name).cast(:str).n_unique.alias("#{column.name}__unique_count")
13+
Polars.col(column.name)
14+
.cast(Polars::String)
15+
.n_unique
16+
.alias("#{column.name}__unique_count")
1417
end
1518
end
1619
end

easy_ml-0.2.0.pre.rc103.gem

933 KB
Binary file not shown.

0 commit comments

Comments
 (0)