Commit bc76ab1

Fix the precision when weights are in bfloat16 for GroupNorm (#1587)

* Fix the precision when weights are in bfloat16
* Use channels-first format for full bf16 mode of GroupNorm
* Minor fix
* Revert use of channels-first format for full bf16 mode of GroupNorm
* Fix format
* Use the first CL implementation for GroupNorm forward when parameters are in bf16
* Remove use of the first CL implementation
* Eliminate redundant code
* Minor changes
* Add data type checks for GroupNorm backward
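To see why bfloat16 weights and intermediates can hurt GroupNorm's precision, here is a small illustrative sketch (not IPEX's actual kernel): it simulates bfloat16 by truncating a float32 value to its top 16 bits (7 stored mantissa bits) and compares a mean accumulated entirely in that reduced precision against one accumulated in full precision. The helper names are made up for this example.

```python
import struct

def to_bf16(x: float) -> float:
    """Simulate bfloat16 by truncating the low 16 bits of the float32
    bit pattern (illustrative; real hardware uses round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def mean_bf16(xs):
    """Accumulate the mean entirely in simulated bfloat16: every partial
    sum is rounded, so small addends are eventually lost."""
    acc = 0.0
    for x in xs:
        acc = to_bf16(acc + to_bf16(x))
    return to_bf16(acc / len(xs))

def mean_fp32(xs):
    """Accumulate in full precision (Python doubles) for comparison."""
    return sum(xs) / len(xs)

# 1000 values near 1.0, as might appear inside one normalization group.
xs = [1.0 + i * 1e-3 for i in range(1000)]
print(mean_fp32(xs))  # close to 1.4995
print(mean_bf16(xs))  # far smaller: the bf16 partial sum stalls
```

Once the bf16 partial sum grows past a few hundred, its representable step exceeds the addends and accumulation stalls, which is the kind of error that keeping statistics in float32 avoids.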
1 parent d8723df commit bc76ab1

File tree

3 files changed: +483 -249 lines changed

csrc/cpu/aten/GroupNorm.cpp (7 additions, 0 deletions)
@@ -136,6 +136,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm_backward(
   TORCH_CHECK(
       X.suggest_memory_format() == dY.suggest_memory_format(),
       "Expected memory formats of X and dY are same.");
+  TORCH_CHECK(
+      X.scalar_type() == dY.scalar_type(),
+      "Expected scalar type of X and dY are same.");
+  bool mixed_type = at::native::is_mixed_type(dY, mean, rstd, gamma);
+  if (mixed_type) {
+    at::native::check_mixed_data_type(dY, mean, rstd, gamma);
+  }
   at::Tensor dX;
   at::Tensor dgamma;
   at::Tensor dbeta;
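The added checks validate dtype combinations before the backward kernel runs. As a rough, hypothetical Python analogue of what such mixed-type validation might look like (the real `at::native::is_mixed_type` and `at::native::check_mixed_data_type` live inside ATen and their exact rules may differ), one could write:

```python
# Hypothetical sketch: reduced-precision inputs are assumed to pair
# only with float32 statistics/parameters; any other mix is rejected.
ALLOWED_PARAM_DTYPE = {"bfloat16": "float32", "float16": "float32"}

def is_mixed_type(input_dtype, *param_dtypes):
    """True if any parameter dtype differs from the input dtype."""
    return any(d is not None and d != input_dtype for d in param_dtypes)

def check_mixed_data_type(input_dtype, *param_dtypes):
    """Reject dtype mixes other than reduced-precision input + float32 params."""
    expected = ALLOWED_PARAM_DTYPE.get(input_dtype)
    if expected is None:
        raise TypeError(f"mixed dtypes not supported for {input_dtype} input")
    for d in param_dtypes:
        if d is not None and d != expected:
            raise TypeError(f"expected {expected} parameter, got {d}")
```

This mirrors the diff's flow: first confirm the gradients and input agree, then, only when dtypes are mixed, verify the mix is a supported reduced-precision combination.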
