File tree Expand file tree Collapse file tree 4 files changed +54
-0
lines changed Expand file tree Collapse file tree 4 files changed +54
-0
lines changed Original file line number Diff line number Diff line change @@ -57,6 +57,14 @@ Tensor _s_binomial_xpu(
5757 return ret;
5858}
5959
60+ Tensor _s_gamma_xpu (const Tensor& alpha, c10::optional<Generator> gen_) {
61+ auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
62+ gen_, at::xpu::detail::getDefaultXPUGenerator ());
63+ Tensor ret = at::empty (alpha.sizes (), alpha.options ());
64+ xpu::launch_gamma_kernel (ret, alpha, gen);
65+ return ret;
66+ }
67+
6068Tensor _sample_dirichlet_xpu (
6169 const Tensor& alpha,
6270 std::optional<Generator> generator) {
@@ -74,6 +82,17 @@ Tensor _sample_dirichlet_xpu(
7482 return ret;
7583}
7684
85+ Tensor _standard_gamma_grad_xpu (const Tensor& self, const Tensor& output) {
86+ Tensor ret = at::empty (self.sizes (), self.options ());
87+ TensorIterator iter = TensorIteratorConfig ()
88+ .add_output (ret)
89+ .add_input (self)
90+ .add_input (output)
91+ .build ();
92+ xpu::launch_standard_gamma_grad_kernel (iter);
93+ return ret;
94+ }
95+
7796Tensor _dirichlet_grad_xpu (
7897 const Tensor& x,
7998 const Tensor& alpha,
Original file line number Diff line number Diff line change @@ -199,6 +199,26 @@ void launch_gamma_kernel(
199199 [&] { gamma_kernel<scalar_t >(ret, alpha, rng_engine_inputs); });
200200}
201201
202+ template <typename scalar_t , typename accscalar_t >
203+ struct StandardGammaGradKernelFunctor {
204+ scalar_t operator ()(scalar_t self_val, scalar_t output_val) const {
205+ return standard_gamma_grad_one<scalar_t , accscalar_t >(self_val, output_val);
206+ }
207+ };
208+
209+ void launch_standard_gamma_grad_kernel (TensorIteratorBase& iter) {
210+ AT_DISPATCH_FLOATING_TYPES_AND2 (
211+ at::ScalarType::Half,
212+ at::ScalarType::BFloat16,
213+ iter.input_dtype (),
214+ " _standard_gamma_grad_xpu" ,
215+ [&] {
216+ using accscalar_t = at::acc_type_device<scalar_t , kXPU >;
217+ StandardGammaGradKernelFunctor<scalar_t , accscalar_t > f;
218+ gpu_kernel (iter, f);
219+ });
220+ }
221+
202222template <typename scalar_t >
203223struct DirichletKernelFunctor {
204224 scalar_t operator ()(scalar_t gamma, scalar_t gamma_sum) const {
Original file line number Diff line number Diff line change @@ -19,6 +19,8 @@ TORCH_XPU_API void launch_gamma_kernel(
1919 const Tensor& alpha,
2020 XPUGeneratorImpl* gen);
2121
22+ TORCH_XPU_API void launch_standard_gamma_grad_kernel (TensorIteratorBase& iter);
23+
2224TORCH_XPU_API void launch_dirichlet_kernel (TensorIteratorBase& iter);
2325
2426TORCH_XPU_API void launch_dirichlet_grad_kernel (TensorIteratorBase& iter);
Original file line number Diff line number Diff line change 53775377 tags : nondeterministic_seeded
53785378 autogen : binomial.out
53795379
5380+ - func : _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
5381+ variants : function
5382+ dispatch :
5383+ XPU : _standard_gamma_grad_xpu
5384+ autogen : _standard_gamma_grad.out
5385+
5386+ - func : _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
5387+ variants : function
5388+ dispatch :
5389+ XPU : _s_gamma_xpu
5390+ tags : nondeterministic_seeded
5391+ autogen : _standard_gamma.out
5392+
53805393- func : _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
53815394 tags : nondeterministic_seeded
53825395 variants : function
You can’t perform that action at this time.
0 commit comments