Skip to content

Commit e38812e

Browse files
committed
Apply comment suggestions
1 parent 05f5d26 commit e38812e

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

lib/scholar/naive_bayes/categorical.ex

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ defmodule Scholar.NaiveBayes.Categorical do
476476
},
477477
x
478478
) do
479-
{_, _, _, jll} =
480-
while {i = 0, feature_log_probability, x, jll = Nx.broadcast(0.0, Nx.shape(x))},
479+
{_, jll} =
480+
while {{i = 0, feature_log_probability, x}, jll = Nx.broadcast(0.0, Nx.shape(x))},
481481
i < Nx.axis_size(x, 1) do
482482
indices = Nx.slice_along_axis(x, i, 1, axis: 1) |> Nx.squeeze(axes: [1])
483483

@@ -488,33 +488,32 @@ defmodule Scholar.NaiveBayes.Categorical do
488488
|> Nx.transpose()
489489
|> Nx.add(jll)
490490

491-
{i + 1, feature_log_probability, x, jll}
491+
{{i + 1, feature_log_probability, x}, jll}
492492
end
493493

494-
total_jll = jll + class_log_priors
495-
total_jll
494+
jll + class_log_priors
496495
end
497496

498497
defnp count(x, y_weighted, num_features, num_classes, num_categories, num_samples) do
499498
class_count = Nx.sum(y_weighted, axes: [0])
500499

501500
feature_count = Nx.broadcast(0.0, {num_features, num_classes, num_categories})
502501

503-
{_, _, _, _, feature_count} =
504-
while {i = 0, x, y_weighted, num_features, feature_count}, i < num_samples do
505-
{_, _, _, _, _, feature_count} =
506-
while {j = 0, x, y_weighted, i, num_features, feature_count}, j < num_features do
502+
{_, feature_count} =
503+
while {{i = 0, x, y_weighted, num_features}, feature_count}, i < num_samples do
504+
{_, feature_count} =
505+
while {{j = 0, x, y_weighted, i, num_features}, feature_count}, j < num_features do
507506
category_value = x[i][j]
508507
class_label = Nx.argmax(y_weighted[i])
509508

510509
index = Nx.stack([j, class_label, category_value])
511510

512511
feature_count = Nx.indexed_add(feature_count, index, y_weighted[i][class_label])
513512

514-
{j + 1, x, y_weighted, i, num_features, feature_count}
513+
{{j + 1, x, y_weighted, i, num_features}, feature_count}
515514
end
516515

517-
{i + 1, x, y_weighted, num_features, feature_count}
516+
{{i + 1, x, y_weighted, num_features}, feature_count}
518517
end
519518

520519
{class_count, feature_count}
@@ -523,8 +522,8 @@ defmodule Scholar.NaiveBayes.Categorical do
523522
defnp compute_feature_log_probability(feature_count, alpha, min_categories) do
524523
feature_log_probability = Nx.broadcast(0.0, Nx.shape(feature_count))
525524

526-
{_, _, _, _, feature_log_probability} =
527-
while {i = 0, feature_count, alpha, min_categories, feature_log_probability},
525+
{_, feature_log_probability} =
526+
while {{i = 0, feature_count, alpha, min_categories}, feature_log_probability},
528527
i < Nx.axis_size(feature_count, 0) do
529528
smoothed_class_count =
530529
Nx.sum(feature_count[i], axes: [1])
@@ -548,7 +547,7 @@ defmodule Scholar.NaiveBayes.Categorical do
548547
feature_log_probability =
549548
Nx.put_slice(feature_log_probability, [i, 0, 0], smoothed_cat_count)
550549

551-
{i + 1, feature_count, alpha, min_categories, feature_log_probability}
550+
{{i + 1, feature_count, alpha, min_categories}, feature_log_probability}
552551
end
553552

554553
feature_log_probability

0 commit comments

Comments
 (0)