Skip to content

Commit 269d9b0

Browse files
authored
Zhijxu/fix softmax cudnn bf16 (#21045)
When seq > 2048, ORT falls back to the cuDNN softmax implementation, which throws an exception when the dtype is bf16. This PR fixes that.
1 parent 5b5ce0b commit 269d9b0

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

onnxruntime/core/providers/cuda/cudnn_common.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
174174

175175
template <>
176176
cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
177+
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
178+
return CUDNN_DATA_BFLOAT16;
179+
#else
177180
ORT_THROW("cuDNN doesn't support BFloat16.");
181+
#endif
178182
}
179183

180184
template <>

onnxruntime/test/common/random_generator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class RandomValueGenerator {
7070
// Random values generated are in the range [min, max).
7171
template <typename TFloat16>
7272
typename std::enable_if<
73-
std::is_same_v<TFloat16, MLFloat16>,
73+
std::is_same_v<TFloat16, MLFloat16> || std::is_same_v<TFloat16, BFloat16>,
7474
std::vector<TFloat16>>::type
7575
Uniform(gsl::span<const int64_t> dims, float min, float max) {
7676
std::vector<TFloat16> val(detail::SizeFromDims(dims));

orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,38 @@ def test_onnx_ops(self):
146146
device = torch.device(device_name)
147147
self.gradient_correctness(name, device)
148148

149+
@unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
150+
def test_softmax_bf16_large(self):
151+
if not torch.cuda.is_available():
152+
# only test bf16 on cuda
153+
return
154+
155+
class Model(torch.nn.Module):
156+
def __init__(self):
157+
super().__init__()
158+
159+
def forward(self, input):
160+
out = torch.softmax(input, dim=-1)
161+
return out
162+
163+
device = "cuda:0"
164+
input_shape = [2, 4096]
165+
# run torch to get the expected result
166+
data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10
167+
data_torch.requires_grad = True
168+
torch_model = Model()
169+
torch_res = torch_model(input=data_torch)
170+
init_grad = torch.ones_like(torch_res)
171+
torch_res.backward(gradient=init_grad)
172+
# run ort
173+
ort_model = ORTModule(torch_model)
174+
data_ort = data_torch.detach().clone()
175+
data_ort.requires_grad = True
176+
ort_res = ort_model(input=data_ort)
177+
ort_res.backward(gradient=init_grad)
178+
# compara result
179+
torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
180+
149181

150182
if __name__ == "__main__":
151183
unittest.main()

0 commit comments

Comments
 (0)