Skip to content

Commit 5458bae

Browse files
committed
Use mul_add for numeric ops and minor cleanups
Replace many manual multiply-then-add expressions with f64::mul_add for improved precision and performance, and reduce unnecessary cloning and duplication. Changes: use mul_add in moarstats.rs and stats.rs for the variance, sum-of-squares, mutual-information, and entropy calculations; move format_output_str earlier in describegpt.rs to remove a duplicate definition, and avoid an extra clone when building PrepareContextOutput; remove an unnecessary pos.clone() in join.rs when seeking the second CSV reader. These edits improve numeric stability and remove small inefficiencies across the codebase.
1 parent 865aeb4 commit 5458bae

File tree

4 files changed

+36
-34
lines changed

4 files changed

+36
-34
lines changed

src/cmd/describegpt.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,6 +2794,16 @@ fn process_phase_output(
27942794
base_url: &str,
27952795
output_format: OutputFormat,
27962796
) -> CliResult<()> {
2797+
// For non-dictionary types, format output
2798+
fn format_output_str(str: &str) -> String {
2799+
str.replace("\\n", "\n")
2800+
.replace("\\t", "\t")
2801+
.replace("\\\"", "\"")
2802+
.replace("\\'", "'")
2803+
.replace("\\`", "`")
2804+
+ "\n\n"
2805+
}
2806+
27972807
// Skip outputting dictionary when using --prompt (but still generate it for context)
27982808
if kind == PromptType::Dictionary && args.flag_prompt.is_some() {
27992809
let (stats_records, ordered_col_names) = parse_stats_csv(&analysis_results.stats)?;
@@ -2921,16 +2931,6 @@ fn process_phase_output(
29212931
return Ok(());
29222932
}
29232933

2924-
// For non-dictionary types, format output
2925-
fn format_output_str(str: &str) -> String {
2926-
str.replace("\\n", "\n")
2927-
.replace("\\t", "\t")
2928-
.replace("\\\"", "\"")
2929-
.replace("\\'", "'")
2930-
.replace("\\`", "`")
2931-
+ "\n\n"
2932-
}
2933-
29342934
let is_sql_response = kind == PromptType::Prompt
29352935
&& args.flag_sql_results.is_some()
29362936
&& completion_response.response.contains("```sql");
@@ -5176,7 +5176,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
51765176

51775177
let output = PrepareContextOutput {
51785178
phases,
5179-
analysis_results: analysis_results.clone(),
5179+
analysis_results,
51805180
model,
51815181
max_tokens: args.flag_max_tokens,
51825182
};

src/cmd/join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
375375
let mut row1 = csv::ByteRecord::new();
376376
let rdr2_has_headers = self.rdr2.has_headers();
377377
while self.rdr1.read_byte_record(&mut row1)? {
378-
self.rdr2.seek(pos.clone())?;
378+
self.rdr2.seek(pos)?;
379379
if rdr2_has_headers {
380380
// Read and skip the header row, since CSV readers disable
381381
// the header skipping logic after being seeked.

src/cmd/moarstats.rs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,9 +1399,9 @@ fn update_correlation_state(state: &mut CorrelationState, x: f64, y: f64) {
13991399
let delta_x_new = x - state.mean_x;
14001400
let delta_y_new = y - state.mean_y;
14011401

1402-
state.m2_x += delta_x * delta_x_new;
1403-
state.m2_y += delta_y * delta_y_new;
1404-
state.cxy += delta_x * delta_y_new;
1402+
state.m2_x = delta_x.mul_add(delta_x_new, state.m2_x);
1403+
state.m2_y = delta_y.mul_add(delta_y_new, state.m2_y);
1404+
state.cxy = delta_x.mul_add(delta_y_new, state.cxy);
14051405
}
14061406

14071407
/// Merge two correlation states (for aggregating across chunks)
@@ -1734,7 +1734,7 @@ fn compute_mutual_information_from_counts(
17341734
let p_y = y_counts.get(y_val).copied().unwrap_or(0) as f64 / total_f64;
17351735

17361736
if p_x > 0.0 && p_y > 0.0 && p_xy > 0.0 {
1737-
mi += p_xy * (p_xy / (p_x * p_y)).log2();
1737+
mi = p_xy.mul_add((p_xy / (p_x * p_y)).log2(), mi);
17381738
}
17391739
}
17401740

@@ -1756,7 +1756,7 @@ fn compute_entropy_from_counts(counts: &HashMap<String, u64>, total: u64) -> Opt
17561756
for count in counts.values() {
17571757
if *count > 0 {
17581758
let p = *count as f64 / total_f64;
1759-
entropy -= p * p.log2();
1759+
entropy = p.mul_add(-p.log2(), entropy);
17601760
}
17611761
}
17621762

@@ -1877,7 +1877,8 @@ where
18771877
.max_winsorized
18781878
.map_or(winsorized_val, |m| m.max(winsorized_val)),
18791879
);
1880-
stats.sum_squares_winsorized += winsorized_val * winsorized_val;
1880+
stats.sum_squares_winsorized =
1881+
winsorized_val.mul_add(winsorized_val, stats.sum_squares_winsorized);
18811882

18821883
// For trimmed mean, only include values within thresholds
18831884
if val >= field_info.lower_threshold && val <= field_info.upper_threshold {
@@ -1886,42 +1887,42 @@ where
18861887
// Track trimmed min/max and sum of squares
18871888
stats.min_trimmed = Some(stats.min_trimmed.map_or(val, |m| m.min(val)));
18881889
stats.max_trimmed = Some(stats.max_trimmed.map_or(val, |m| m.max(val)));
1889-
stats.sum_squares_trimmed += val * val;
1890+
stats.sum_squares_trimmed = val.mul_add(val, stats.sum_squares_trimmed);
18901891
}
18911892

18921893
// Count outliers and track statistics based on fence comparisons
18931894
if val < field_info.lower_outer {
18941895
stats.counts[0] += 1; // extreme_lower
18951896
stats.counts[5] += 1; // total
18961897
stats.sum_outliers += val;
1897-
stats.sum_squares_outliers += val * val;
1898+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
18981899
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
18991900
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
19001901
} else if val < field_info.lower_inner {
19011902
stats.counts[1] += 1; // mild_lower
19021903
stats.counts[5] += 1; // total
19031904
stats.sum_outliers += val;
1904-
stats.sum_squares_outliers += val * val;
1905+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
19051906
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
19061907
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
19071908
} else if val <= field_info.upper_inner {
19081909
stats.counts[2] += 1; // normal
19091910
stats.sum_normal += val;
1910-
stats.sum_squares_normal += val * val;
1911+
stats.sum_squares_normal = val.mul_add(val, stats.sum_squares_normal);
19111912
stats.min_normal = Some(stats.min_normal.map_or(val, |m| m.min(val)));
19121913
stats.max_normal = Some(stats.max_normal.map_or(val, |m| m.max(val)));
19131914
} else if val <= field_info.upper_outer {
19141915
stats.counts[3] += 1; // mild_upper
19151916
stats.counts[5] += 1; // total
19161917
stats.sum_outliers += val;
1917-
stats.sum_squares_outliers += val * val;
1918+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
19181919
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
19191920
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
19201921
} else {
19211922
stats.counts[4] += 1; // extreme_upper
19221923
stats.counts[5] += 1; // total
19231924
stats.sum_outliers += val;
1924-
stats.sum_squares_outliers += val * val;
1925+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
19251926
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
19261927
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
19271928
}
@@ -2332,7 +2333,8 @@ fn count_all_outliers_from_reader(
23322333
.max_winsorized
23332334
.map_or(winsorized_val, |m| m.max(winsorized_val)),
23342335
);
2335-
stats.sum_squares_winsorized += winsorized_val * winsorized_val;
2336+
stats.sum_squares_winsorized =
2337+
winsorized_val.mul_add(winsorized_val, stats.sum_squares_winsorized);
23362338

23372339
// For trimmed mean, only include values within thresholds
23382340
if val >= field_info.lower_threshold && val <= field_info.upper_threshold {
@@ -2341,42 +2343,42 @@ fn count_all_outliers_from_reader(
23412343
// Track trimmed min/max and sum of squares
23422344
stats.min_trimmed = Some(stats.min_trimmed.map_or(val, |m| m.min(val)));
23432345
stats.max_trimmed = Some(stats.max_trimmed.map_or(val, |m| m.max(val)));
2344-
stats.sum_squares_trimmed += val * val;
2346+
stats.sum_squares_trimmed = val.mul_add(val, stats.sum_squares_trimmed);
23452347
}
23462348

23472349
// Count outliers and track statistics based on fence comparisons
23482350
if val < field_info.lower_outer {
23492351
stats.counts[0] += 1; // extreme_lower
23502352
stats.counts[5] += 1; // total
23512353
stats.sum_outliers += val;
2352-
stats.sum_squares_outliers += val * val;
2354+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
23532355
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
23542356
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
23552357
} else if val < field_info.lower_inner {
23562358
stats.counts[1] += 1; // mild_lower
23572359
stats.counts[5] += 1; // total
23582360
stats.sum_outliers += val;
2359-
stats.sum_squares_outliers += val * val;
2361+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
23602362
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
23612363
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
23622364
} else if val <= field_info.upper_inner {
23632365
stats.counts[2] += 1; // normal
23642366
stats.sum_normal += val;
2365-
stats.sum_squares_normal += val * val;
2367+
stats.sum_squares_normal = val.mul_add(val, stats.sum_squares_normal);
23662368
stats.min_normal = Some(stats.min_normal.map_or(val, |m| m.min(val)));
23672369
stats.max_normal = Some(stats.max_normal.map_or(val, |m| m.max(val)));
23682370
} else if val <= field_info.upper_outer {
23692371
stats.counts[3] += 1; // mild_upper
23702372
stats.counts[5] += 1; // total
23712373
stats.sum_outliers += val;
2372-
stats.sum_squares_outliers += val * val;
2374+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
23732375
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
23742376
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
23752377
} else {
23762378
stats.counts[4] += 1; // extreme_upper
23772379
stats.counts[5] += 1; // total
23782380
stats.sum_outliers += val;
2379-
stats.sum_squares_outliers += val * val;
2381+
stats.sum_squares_outliers = val.mul_add(val, stats.sum_squares_outliers);
23802382
stats.min_outliers = Some(stats.min_outliers.map_or(val, |m| m.min(val)));
23812383
stats.max_outliers = Some(stats.max_outliers.map_or(val, |m| m.max(val)));
23822384
}

src/cmd/stats.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,13 +2663,13 @@ impl WeightedOnlineStats {
26632663
self.sum_weights += w;
26642664

26652665
let delta = x - self.weighted_mean;
2666-
self.weighted_mean += (w / self.sum_weights) * delta;
2666+
self.weighted_mean = (w / self.sum_weights).mul_add(delta, self.weighted_mean);
26672667
let delta2 = x - self.weighted_mean;
2668-
self.sum_squared_diffs += w * delta * delta2;
2668+
self.sum_squared_diffs = (w * delta).mul_add(delta2, self.sum_squared_diffs);
26692669

26702670
// Accumulate weighted logs for geometric mean (only if x > 0)
26712671
if x > 0.0 {
2672-
self.sum_weighted_logs += w * x.ln();
2672+
self.sum_weighted_logs = w.mul_add(x.ln(), self.sum_weighted_logs);
26732673
self.sum_weights_positive += w;
26742674
}
26752675

0 commit comments

Comments (0)