Skip to content

Commit 0cd3091

Browse files
authored
Add aten::_standard_gamma (#1040)
- `_standard_gamma` - `_standard_gamma_grad`
1 parent 2ac2c45 commit 0cd3091

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

src/ATen/native/xpu/Distributions.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ Tensor _s_binomial_xpu(
5757
return ret;
5858
}
5959

60+
Tensor _s_gamma_xpu(const Tensor& alpha, c10::optional<Generator> gen_) {
61+
auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
62+
gen_, at::xpu::detail::getDefaultXPUGenerator());
63+
Tensor ret = at::empty(alpha.sizes(), alpha.options());
64+
xpu::launch_gamma_kernel(ret, alpha, gen);
65+
return ret;
66+
}
67+
6068
Tensor _sample_dirichlet_xpu(
6169
const Tensor& alpha,
6270
std::optional<Generator> generator) {
@@ -74,6 +82,17 @@ Tensor _sample_dirichlet_xpu(
7482
return ret;
7583
}
7684

85+
Tensor _standard_gamma_grad_xpu(const Tensor& self, const Tensor& output) {
86+
Tensor ret = at::empty(self.sizes(), self.options());
87+
TensorIterator iter = TensorIteratorConfig()
88+
.add_output(ret)
89+
.add_input(self)
90+
.add_input(output)
91+
.build();
92+
xpu::launch_standard_gamma_grad_kernel(iter);
93+
return ret;
94+
}
95+
7796
Tensor _dirichlet_grad_xpu(
7897
const Tensor& x,
7998
const Tensor& alpha,

src/ATen/native/xpu/sycl/Distributions.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,26 @@ void launch_gamma_kernel(
199199
[&] { gamma_kernel<scalar_t>(ret, alpha, rng_engine_inputs); });
200200
}
201201

202+
template <typename scalar_t, typename accscalar_t>
203+
struct StandardGammaGradKernelFunctor {
204+
scalar_t operator()(scalar_t self_val, scalar_t output_val) const {
205+
return standard_gamma_grad_one<scalar_t, accscalar_t>(self_val, output_val);
206+
}
207+
};
208+
209+
void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter) {
210+
AT_DISPATCH_FLOATING_TYPES_AND2(
211+
at::ScalarType::Half,
212+
at::ScalarType::BFloat16,
213+
iter.input_dtype(),
214+
"_standard_gamma_grad_xpu",
215+
[&] {
216+
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
217+
StandardGammaGradKernelFunctor<scalar_t, accscalar_t> f;
218+
gpu_kernel(iter, f);
219+
});
220+
}
221+
202222
template <typename scalar_t>
203223
struct DirichletKernelFunctor {
204224
scalar_t operator()(scalar_t gamma, scalar_t gamma_sum) const {

src/ATen/native/xpu/sycl/Distributions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ TORCH_XPU_API void launch_gamma_kernel(
1919
const Tensor& alpha,
2020
XPUGeneratorImpl* gen);
2121

22+
TORCH_XPU_API void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter);
23+
2224
TORCH_XPU_API void launch_dirichlet_kernel(TensorIteratorBase& iter);
2325

2426
TORCH_XPU_API void launch_dirichlet_grad_kernel(TensorIteratorBase& iter);

yaml/native/native_functions.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5377,6 +5377,19 @@
53775377
tags: nondeterministic_seeded
53785378
autogen: binomial.out
53795379

5380+
- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
5381+
variants: function
5382+
dispatch:
5383+
XPU: _standard_gamma_grad_xpu
5384+
autogen: _standard_gamma_grad.out
5385+
5386+
- func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
5387+
variants: function
5388+
dispatch:
5389+
XPU: _s_gamma_xpu
5390+
tags: nondeterministic_seeded
5391+
autogen: _standard_gamma.out
5392+
53805393
- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
53815394
tags: nondeterministic_seeded
53825395
variants: function

0 commit comments

Comments
 (0)