@@ -801,6 +801,8 @@ class MultinomialFunctor {
                              .Input("prefix_sum")
                              .Output("out")
                              .Build());
+    op_npu_ =
+        CHECK_JUST(one::OpBuilder("multinomial_with_replacement").Input("x").Output("out").Build());
   }
 
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int& num_samples,
@@ -823,10 +825,17 @@ class MultinomialFunctor {
     CHECK_OR_RETURN(num_categories <= FLOAT32_MAX_CONSECUTIVE_INT)
         << "number of categories cannot exceed 2^24";
 
+    DeviceType input_device = DeviceType::kCPU;
+    if (x->is_global()) {
+      JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));
+      input_device = JUST(x->parallel_desc())->device_type();
+    } else {
+      input_device = JUST(x->device())->enum_type();
+    }
     // Fast-path for no replacement.
     // Reference:
     // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
-    if (!replacement) {
+    if (!replacement && input_device != DeviceType::kNPU) {
       // The algorithm is from gumbel softmax.
       // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
       // Here we can apply exp to the formula which will not affect result of
@@ -847,6 +856,7 @@ class MultinomialFunctor {
       std::shared_ptr<Tensor> result;
       if (num_samples == 1) {
         result = JUST(functional::ArgMax(q, -1, true, JUST(DType::Get(DataType::kInt64))));
+      } else if (input_device == DeviceType::kNPU) {  // dead branch: the fast path is only entered when input_device != kNPU
       } else {
         std::shared_ptr<TensorTuple> temp =
             JUST(functional::TopK(q, num_samples, -1,
@@ -856,23 +866,18 @@ class MultinomialFunctor {
       return result;
     }
 
-    DeviceType input_device = DeviceType::kCPU;
-    if (x->is_global()) {
-      JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));
-      input_device = JUST(x->parallel_desc())->device_type();
-    } else {
-      input_device = JUST(x->device())->enum_type();
-    }
     auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
     gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "num_samples");
-    attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), num_samples);
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "num_samples", "replacement");
+    attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), num_samples, replacement);
 
     const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
     OpExprInterpContext ctx(attrs, distribution_state);
 
     if (input_device == DeviceType::kCPU) {
       return OpInterpUtil::Dispatch<Tensor>(*op_cpu_, {x}, ctx);
+    } else if (input_device == DeviceType::kNPU) {
+      return OpInterpUtil::Dispatch<Tensor>(*op_npu_, {x}, ctx);
     } else {
       std::shared_ptr<Tensor> sum_last_dim = JUST(functional::ReduceSum(x, {-1}, true, NullOpt));
       std::shared_ptr<Tensor> norm_dist = JUST(functional::Div(x, sum_last_dim));
@@ -884,6 +889,7 @@ class MultinomialFunctor {
  private:
   std::shared_ptr<OpExpr> op_cpu_;
   std::shared_ptr<OpExpr> op_gpu_;
+  std::shared_ptr<OpExpr> op_npu_;
 };
 
 }  // namespace impl