Commit a6c044a

atalman and bobrenjc93 authored
[cherry-pick] Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (pytorch#158537) (pytorch#158655)
Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (pytorch#158537)

Fixes pytorch#158376

Pull Request resolved: pytorch#158537
Approved by: https://github.com/atalman
Co-authored-by: bobrenjc93 <[email protected]>
1 parent 620ebd0 commit a6c044a
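
For context, the divergence this change removes is visible directly from Python. Below is a minimal before/after sketch based on the example quoted in the ScalarOps.cpp comment in this commit (float16 is used as the representative low-precision dtype; exact output text may vary by build):

    import torch

    # torch.tensor has always used a relaxed cast: values outside the
    # representable range of float16 saturate to inf.
    print(torch.tensor(1123581321.0, dtype=torch.float16))
    # => tensor(inf, dtype=torch.float16)

    # Before this commit, the aten op used a checked conversion and raised:
    #   RuntimeError: value cannot be converted to type at::Half without overflow
    # With this commit it matches torch.tensor:
    print(torch.ops.aten.scalar_tensor.default(1123581321.0, dtype=torch.float16))
    # => tensor(inf, dtype=torch.float16)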

File tree

2 files changed: +54 -1 lines changed


aten/src/ATen/ScalarOps.cpp

Lines changed: 22 additions & 1 deletion
@@ -8,7 +8,28 @@ namespace at {
 namespace {
 template <typename scalar_t>
 inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {
-  auto value = value_scalar.to<scalar_t>();
+  scalar_t value{};
+
+  if constexpr (std::is_same_v<scalar_t, at::Half> ||
+                std::is_same_v<scalar_t, at::BFloat16> ||
+                std::is_same_v<scalar_t, at::Float8_e5m2> ||
+                std::is_same_v<scalar_t, at::Float8_e5m2fnuz> ||
+                std::is_same_v<scalar_t, at::Float8_e4m3fn> ||
+                std::is_same_v<scalar_t, at::Float8_e4m3fnuz> ||
+                std::is_same_v<scalar_t, at::Float8_e8m0fnu>) {
+    // relaxed float cast: allow inf similar to the torch.tensor constructor
+    //
+    // without this, we had the following divergence:
+    //   torch.tensor(1123581321.0, dtype=torch.float16)
+    //   => tensor(inf, dtype=torch.float16)
+    //   torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16)
+    //   => RuntimeError: value cannot be converted to type at::Half without overflow
+
+    value = static_cast<scalar_t>(value_scalar.to<double>());
+  } else {
+    value = value_scalar.to<scalar_t>();
+  }
+
   scalar_t* dptr = static_cast<scalar_t*>(self.data_ptr());
   *dptr = value;
 }
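
The relaxed cast applies only to the low-precision floating dtypes enumerated in the if constexpr; every other dtype still takes the checked Scalar conversion in the else branch. A hedged sketch of the resulting behavior follows (the int8 case is an assumption inferred from the unchanged else branch, not something this commit tests):

    import torch

    # float16 overflows and now saturates to inf; bfloat16 has a much wider
    # exponent range, so the same value merely rounds instead of overflowing.
    print(torch.ops.aten.scalar_tensor.default(1123581321.0, dtype=torch.float16))
    print(torch.ops.aten.scalar_tensor.default(1123581321.0, dtype=torch.bfloat16))

    # Dtypes outside the list keep the checked conversion, so an out-of-range
    # value should still raise (assumed, per the else branch above):
    try:
        torch.ops.aten.scalar_tensor.default(2**40, dtype=torch.int8)
    except RuntimeError as e:
        print(e)  # e.g. "value cannot be converted to type ... without overflow"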

test/dynamo/test_misc.py

Lines changed: 32 additions & 0 deletions
@@ -12975,6 +12975,38 @@ def f(actions, n_act, epsilon=0.1):
         y = torch.tensor(5)
         f(x, y)
 
+    def test_dynamic_float_scalar_tensor_coersion(self):
+        # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
+        class Foo:
+            def __init__(self):
+                self.config = type(
+                    "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6}
+                )
+
+            @torch.compile(fullgraph=True)
+            def forward(self, input):
+                outputs = torch.where(
+                    torch.abs(input - self.config.pad_val) < self.config.tolerance,
+                    torch.tensor(
+                        self.config.pad_val, dtype=input.dtype, device=input.device
+                    ),
+                    torch.tensor(
+                        self.config.pad_val + 1, dtype=input.dtype, device=input.device
+                    ),
+                )
+                return outputs
+
+        foo = Foo()
+        inputs = torch.randn(3, 4)
+        result = foo.forward(inputs)
+
+        original_pad_val = foo.config.pad_val
+        foo.config.pad_val += 1.0
+        result2 = foo.forward(inputs)
+
+        # Previously would crash with:
+        # RuntimeError: value cannot be converted to type at::Half without overflow
+
 
 devices = ("cuda", "hpu")
 instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
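
Reading the test: the first foo.forward call compiles with pad_val baked in as a constant; mutating pad_val and calling again makes Dynamo treat the float as dynamic, which materializes the scalar through aten.scalar_tensor — the path that previously raised the overflow error. (This reading is inferred from the test name and the linked issue pytorch#158376; the test itself only checks that the second call no longer crashes.)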
