
Commit ce19aa9

hjheetoyxu and Yutao Xu authored
Add aten::_fused_adam_, aten::_fused_adamw_ (#879)
- _fused_adam_
- _fused_adam_.tensor_lr
- _fused_adamw_
- _fused_adamw_.tensor_lr

Co-authored-by: Yutao Xu <[email protected]>
1 parent 5e29831 commit ce19aa9

12 files changed: +1421 −0 lines changed
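For orientation, the sketch below shows roughly how the new op would be invoked from C++ once it is registered for XPU. It is a minimal, hypothetical example and not part of this commit: it assumes the generated at::_fused_adam_ API mirrors the _fused_adam_kernel_xpu_ signature added in FusedAdam.cpp below, and the tensor shapes and hyperparameters are made up.

#include <ATen/ATen.h>

// Hypothetical usage sketch: one fused Adam step over a single parameter on
// an XPU device. amsgrad is false, so max_exp_avg_sqs is passed as an empty list.
void adam_step_example() {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);

  at::Tensor param = at::randn({1024}, opts);       // model parameter
  at::Tensor grad = at::randn({1024}, opts);        // its gradient
  at::Tensor exp_avg = at::zeros({1024}, opts);     // first-moment state
  at::Tensor exp_avg_sq = at::zeros({1024}, opts);  // second-moment state
  at::Tensor step = at::zeros({1}, opts);           // state_steps entry

  at::_fused_adam_(
      {param}, {grad}, {exp_avg}, {exp_avg_sq},
      /*max_exp_avg_sqs=*/{}, {step},
      /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
      /*weight_decay=*/0.0, /*eps=*/1e-8,
      /*amsgrad=*/false, /*maximize=*/false);
}

In practice these kernels would typically be reached from Python through torch.optim.Adam / torch.optim.AdamW with fused=True on XPU tensors, provided the dispatch registration in this commit wires _fused_adam_ and _fused_adamw_ to the XPU backend.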

src/ATen/native/xpu/FusedAdam.cpp

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
#include <ATen/native/ForeachUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#endif

#include <ATen/native/xpu/sycl/FusedAdamKernels.h>

namespace at {
namespace native {

void _fused_adam_kernel_xpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adam_amsgrad_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adam_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

// overload with tensor lr(single element tensor) input
void _fused_adam_kernel_xpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_adam_kernel_xpu_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale != std::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != std::nullopt) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adam_amsgrad_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adam_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

} // namespace native
} // namespace at
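The tensor_lr overload above accepts the learning rate as a single-element tensor: a CPU lr is converted with lr.item<double>() and routed to the scalar overload, while a device lr must sit on the same device as the params and is forwarded to the kernels as a tensor, avoiding a host synchronization. Below is a hypothetical sketch of the two call forms, again assuming the generated at::_fused_adam_ tensor-lr overload mirrors the kernel signature above; the same pattern applies to _fused_adamw_ in the next file.

#include <ATen/ATen.h>

// Hypothetical sketch: calling the tensor-lr overload with a device lr versus
// a CPU lr. The optimizer state lists are assumed to be prepared as in the
// previous example; values are illustrative only.
void tensor_lr_example(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps) {
  // lr living on the XPU (e.g. updated by a scheduler on the device):
  // stays a tensor all the way into the kernel, no .item() sync.
  at::Tensor lr_xpu = at::full(
      {}, 1e-3, at::TensorOptions().dtype(at::kFloat).device(at::kXPU));
  at::_fused_adam_(
      params, grads, exp_avgs, exp_avg_sqs,
      /*max_exp_avg_sqs=*/{}, state_steps,
      lr_xpu, /*beta1=*/0.9, /*beta2=*/0.999,
      /*weight_decay=*/0.0, /*eps=*/1e-8,
      /*amsgrad=*/false, /*maximize=*/false);

  // lr living on the CPU: the overload falls back to lr.item<double>() and
  // dispatches to the scalar-lr entry point shown earlier.
  at::Tensor lr_cpu = at::full({}, 1e-3, at::TensorOptions().dtype(at::kFloat));
  at::_fused_adam_(
      params, grads, exp_avgs, exp_avg_sqs,
      /*max_exp_avg_sqs=*/{}, state_steps,
      lr_cpu, /*beta1=*/0.9, /*beta2=*/0.999,
      /*weight_decay=*/0.0, /*eps=*/1e-8,
      /*amsgrad=*/false, /*maximize=*/false);
}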

src/ATen/native/xpu/FusedAdamW.cpp

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
#include <ATen/native/ForeachUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif

#include <ATen/native/xpu/sycl/FusedAdamWKernels.h>

namespace at {
namespace native {

void _fused_adamw_kernel_xpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adamw_amsgrad_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adamw_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

// overload with tensor lr(single element tensor) input
void _fused_adamw_kernel_xpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_adamw_kernel_xpu_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }
  Device param_device = params[0].device();
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adamw_amsgrad_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    xpu::fused_adamw_kernel(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}

} // namespace native
} // namespace at
