@@ -476,8 +476,8 @@ defmodule Scholar.NaiveBayes.Categorical do
476476 } ,
477477 x
478478 ) do
479- { _ , _ , _ , jll } =
480- while { i = 0 , feature_log_probability , x , jll = Nx . broadcast ( 0.0 , Nx . shape ( x ) ) } ,
479+ { _ , jll } =
480+ while { { i = 0 , feature_log_probability , x } , jll = Nx . broadcast ( 0.0 , Nx . shape ( x ) ) } ,
481481 i < Nx . axis_size ( x , 1 ) do
482482 indices = Nx . slice_along_axis ( x , i , 1 , axis: 1 ) |> Nx . squeeze ( axes: [ 1 ] )
483483
@@ -488,33 +488,32 @@ defmodule Scholar.NaiveBayes.Categorical do
488488 |> Nx . transpose ( )
489489 |> Nx . add ( jll )
490490
491- { i + 1 , feature_log_probability , x , jll }
491+ { { i + 1 , feature_log_probability , x } , jll }
492492 end
493493
494- total_jll = jll + class_log_priors
495- total_jll
494+ jll + class_log_priors
496495 end
497496
498497 defnp count ( x , y_weighted , num_features , num_classes , num_categories , num_samples ) do
499498 class_count = Nx . sum ( y_weighted , axes: [ 0 ] )
500499
501500 feature_count = Nx . broadcast ( 0.0 , { num_features , num_classes , num_categories } )
502501
503- { _ , _ , _ , _ , feature_count } =
504- while { i = 0 , x , y_weighted , num_features , feature_count } , i < num_samples do
505- { _ , _ , _ , _ , _ , feature_count } =
506- while { j = 0 , x , y_weighted , i , num_features , feature_count } , j < num_features do
502+ { _ , feature_count } =
503+ while { { i = 0 , x , y_weighted , num_features } , feature_count } , i < num_samples do
504+ { _ , feature_count } =
505+ while { { j = 0 , x , y_weighted , i , num_features } , feature_count } , j < num_features do
507506 category_value = x [ i ] [ j ]
508507 class_label = Nx . argmax ( y_weighted [ i ] )
509508
510509 index = Nx . stack ( [ j , class_label , category_value ] )
511510
512511 feature_count = Nx . indexed_add ( feature_count , index , y_weighted [ i ] [ class_label ] )
513512
514- { j + 1 , x , y_weighted , i , num_features , feature_count }
513+ { { j + 1 , x , y_weighted , i , num_features } , feature_count }
515514 end
516515
517- { i + 1 , x , y_weighted , num_features , feature_count }
516+ { { i + 1 , x , y_weighted , num_features } , feature_count }
518517 end
519518
520519 { class_count , feature_count }
@@ -523,8 +522,8 @@ defmodule Scholar.NaiveBayes.Categorical do
523522 defnp compute_feature_log_probability ( feature_count , alpha , min_categories ) do
524523 feature_log_probability = Nx . broadcast ( 0.0 , Nx . shape ( feature_count ) )
525524
526- { _ , _ , _ , _ , feature_log_probability } =
527- while { i = 0 , feature_count , alpha , min_categories , feature_log_probability } ,
525+ { _ , feature_log_probability } =
526+ while { { i = 0 , feature_count , alpha , min_categories } , feature_log_probability } ,
528527 i < Nx . axis_size ( feature_count , 0 ) do
529528 smoothed_class_count =
530529 Nx . sum ( feature_count [ i ] , axes: [ 1 ] )
@@ -548,7 +547,7 @@ defmodule Scholar.NaiveBayes.Categorical do
548547 feature_log_probability =
549548 Nx . put_slice ( feature_log_probability , [ i , 0 , 0 ] , smoothed_cat_count )
550549
551- { i + 1 , feature_count , alpha , min_categories , feature_log_probability }
550+ { { i + 1 , feature_count , alpha , min_categories } , feature_log_probability }
552551 end
553552
554553 feature_log_probability
0 commit comments