Skip to content

Commit 25c8978

Browse files
support npu multinomial (#10668)
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
1 parent f485d7e commit 25c8978

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

oneflow/core/functional/impl/random_functor.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,8 @@ class MultinomialFunctor {
801801
.Input("prefix_sum")
802802
.Output("out")
803803
.Build());
804+
op_npu_ =
805+
CHECK_JUST(one::OpBuilder("multinomial_with_replacement").Input("x").Output("out").Build());
804806
}
805807

806808
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int& num_samples,
@@ -823,10 +825,17 @@ class MultinomialFunctor {
823825
CHECK_OR_RETURN(num_categories <= FLOAT32_MAX_CONSECUTIVE_INT)
824826
<< "number of categories cannot exceed 2^24";
825827

828+
DeviceType input_device = DeviceType::kCPU;
829+
if (x->is_global()) {
830+
JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));
831+
input_device = JUST(x->parallel_desc())->device_type();
832+
} else {
833+
input_device = JUST(x->device())->enum_type();
834+
}
826835
// Fast-path for no replacement.
827836
// Reference:
828837
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
829-
if (!replacement) {
838+
if (!replacement && input_device != DeviceType::kNPU) {
830839
// The algorithm is from gumbel softmax.
831840
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
832841
// Here we can apply exp to the formula which will not affect result of
@@ -847,6 +856,7 @@ class MultinomialFunctor {
847856
std::shared_ptr<Tensor> result;
848857
if (num_samples == 1) {
849858
result = JUST(functional::ArgMax(q, -1, true, JUST(DType::Get(DataType::kInt64))));
859+
} else if (input_device == DeviceType::kNPU) {
850860
} else {
851861
std::shared_ptr<TensorTuple> temp =
852862
JUST(functional::TopK(q, num_samples, -1,
@@ -856,23 +866,18 @@ class MultinomialFunctor {
856866
return result;
857867
}
858868

859-
DeviceType input_device = DeviceType::kCPU;
860-
if (x->is_global()) {
861-
JUST(CheckDeviceIdsIsValid(JUST(x->parallel_desc())));
862-
input_device = JUST(x->parallel_desc())->device_type();
863-
} else {
864-
input_device = JUST(x->device())->enum_type();
865-
}
866869
auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
867870
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
868-
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "num_samples");
869-
attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), num_samples);
871+
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "num_samples", "replacement");
872+
attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), num_samples, replacement);
870873

871874
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
872875
OpExprInterpContext ctx(attrs, distribution_state);
873876

874877
if (input_device == DeviceType::kCPU) {
875878
return OpInterpUtil::Dispatch<Tensor>(*op_cpu_, {x}, ctx);
879+
} else if (input_device == DeviceType::kNPU) {
880+
return OpInterpUtil::Dispatch<Tensor>(*op_npu_, {x}, ctx);
876881
} else {
877882
std::shared_ptr<Tensor> sum_last_dim = JUST(functional::ReduceSum(x, {-1}, true, NullOpt));
878883
std::shared_ptr<Tensor> norm_dist = JUST(functional::Div(x, sum_last_dim));
@@ -884,6 +889,7 @@ class MultinomialFunctor {
884889
private:
885890
std::shared_ptr<OpExpr> op_cpu_;
886891
std::shared_ptr<OpExpr> op_gpu_;
892+
std::shared_ptr<OpExpr> op_npu_;
887893
};
888894

889895
} // namespace impl

oneflow/ir/include/OneFlow/OneFlowUserOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6355,7 +6355,8 @@ def OneFlow_MultinomialWithReplacementOp : OneFlow_BaseOp<"multinomial_with_repl
63556355
);
63566356
let attrs = (ins
63576357
DefaultValuedAttr<SI64Attr, "0">:$seed,
6358-
DefaultValuedAttr<SI32Attr, "1">:$num_samples
6358+
DefaultValuedAttr<SI32Attr, "1">:$num_samples,
6359+
DefaultValuedAttr<BoolAttr, "true">:$replacement
63596360
);
63606361
let has_logical_tensor_desc_infer_fn = 1;
63616362
let has_physical_tensor_desc_infer_fn = 1;

0 commit comments

Comments
 (0)