
Commit 77b9399

KarhouTam authored and pytorchmergebot committed
[random] Add generator arg to rand*_like APIs (pytorch#166160)
Fixes pytorch#165865

## What this PR does?

- [x] Add a `generator` arg to the `rand*_like` APIs (`rand_like()`, `randn_like()`, `randint_like()`).
- [x] Add unit tests for the `rand*_like` APIs.
- [x] Add the corresponding argument docs.
- [x] Refactor the `rand*_like()` code in `TensorFactories.cpp`.
- [x] Add the corresponding (and formerly missed) items in `VmapModeRegistrations.cpp`.

## Example (using `rand_like()`)

```python
import torch

gen0 = torch.Generator()
gen1 = torch.Generator()
gen2 = torch.Generator()

gen0.manual_seed(42)
gen1.manual_seed(42)
gen2.manual_seed(2025)

tensor = torch.empty(10)
t0 = torch.rand_like(tensor, generator=gen0)
t1 = torch.rand_like(tensor, generator=gen1)
t2 = torch.rand_like(tensor, generator=gen2)

assert torch.equal(t0, t1)
assert not torch.equal(t2, t0)
assert not torch.equal(t2, t1)
```

Pull Request resolved: pytorch#166160
Approved by: https://github.com/cyyever, https://github.com/albanD
1 parent 83cd626 commit 77b9399
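As a quick illustration beyond the commit message's `rand_like()` example, here is a minimal sketch (not part of the original commit) of the same `generator` keyword on the other two APIs this PR touches, `randn_like()` and `randint_like()`; the overload shapes are taken from the diffs below:

```python
import torch

gen_a = torch.Generator().manual_seed(7)
gen_b = torch.Generator().manual_seed(7)
base = torch.empty(4, 3)

# randn_like: standard-normal samples drawn from the supplied generator
n0 = torch.randn_like(base, generator=gen_a)
n1 = torch.randn_like(base, generator=gen_b)
assert torch.equal(n0, n1)  # same seed, same generator type -> identical samples

# randint_like: integers in [0, 10) drawn from the supplied generator
gen_a.manual_seed(7)
gen_b.manual_seed(7)
i0 = torch.randint_like(base, 10, generator=gen_a)
i1 = torch.randint_like(base, 10, generator=gen_b)
assert torch.equal(i0, i1)
```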

File tree

8 files changed (+340, -24 lines)


aten/src/ATen/VmapModeRegistrations.cpp

Lines changed: 6 additions & 0 deletions
@@ -72,10 +72,16 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
   m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);
 
   m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("rand_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
   m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("randn_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
 
   m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("randint_like.Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
   m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("randint_like.generator", unsupportedRandomOp<const Tensor&, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("randint_like.Tensor_generator", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
+  m.impl("randint_like.low_generator_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
 
   m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
   m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 124 additions & 12 deletions
@@ -11,6 +11,7 @@
 #include <ATen/SparseCsrTensorUtils.h>
 #include <ATen/TensorOperators.h>
 #include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/UnaryOps.h>
 #include <c10/core/ScalarType.h>
@@ -1089,6 +1090,7 @@ Tensor& rand_out(
 
 Tensor rand_like(
     const Tensor& self,
+    std::optional<Generator> generator,
     std::optional<ScalarType> dtype,
     std::optional<Layout> layout,
     std::optional<Device> device,
@@ -1100,7 +1102,24 @@ Tensor rand_like(
           pin_memory);
 
   auto result = at::empty_like(self, options, optional_memory_format);
-  return result.uniform_(0, 1, std::nullopt);
+  return result.uniform_(0, 1, std::move(generator));
+}
+
+Tensor rand_like(
+    const Tensor& self,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  return native::rand_like(
+      self,
+      static_cast<std::optional<Generator>>(std::nullopt),
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1197,7 +1216,9 @@ Tensor& randint_out(
 
 Tensor randint_like(
     const Tensor& self,
+    int64_t low,
     int64_t high,
+    std::optional<Generator> generator,
     std::optional<ScalarType> dtype,
     std::optional<Layout> layout,
     std::optional<Device> device,
@@ -1209,7 +1230,71 @@ Tensor randint_like(
           pin_memory);
 
   auto result = at::empty_like(self, options, optional_memory_format);
-  return result.random_(0, high, std::nullopt);
+  return result.random_(low, high, std::move(generator));
+}
+
+Tensor randint_like(
+    const Tensor& self,
+    int64_t low,
+    int64_t high,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  return native::randint_like(
+      self,
+      low,
+      high,
+      static_cast<std::optional<Generator>>(std::nullopt),
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
+}
+
+Tensor randint_like(
+    const Tensor& self,
+    int64_t high,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  // See [Note: hacky wrapper removal for TensorOptions]
+  return native::randint_like(
+      self,
+      0,
+      high,
+      static_cast<std::optional<Generator>>(std::nullopt),
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
+}
+
+Tensor randint_like(
+    const Tensor& self,
+    int64_t high,
+    std::optional<Generator> generator,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  // See [Note: hacky wrapper removal for TensorOptions]
+  return native::randint_like(
+      self,
+      0,
+      high,
+      generator,
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
 }
 
 Tensor randint_like(
@@ -1226,7 +1311,9 @@ Tensor randint_like(
   int64_t high_scalar = high.item<int64_t>();
   return at::native::randint_like(
       self,
+      0,
       high_scalar,
+      static_cast<std::optional<Generator>>(std::nullopt),
       dtype,
       layout,
       device,
@@ -1236,20 +1323,27 @@
 
 Tensor randint_like(
     const Tensor& self,
-    int64_t low,
-    int64_t high,
+    const Tensor& high,
+    std::optional<Generator> generator,
     std::optional<ScalarType> dtype,
     std::optional<Layout> layout,
     std::optional<Device> device,
     std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
-  // See [Note: hacky wrapper removal for TensorOptions]
-  TensorOptions options =
-      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
-          pin_memory);
-
-  auto result = at::empty_like(self, options, optional_memory_format);
-  return result.random_(low, high, std::nullopt);
+  TORCH_CHECK(
+      high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
+      "high must be a scalar tensor and on CPU");
+  int64_t high_scalar = high.item<int64_t>();
+  return at::native::randint_like(
+      self,
+      0,
+      high_scalar,
+      generator,
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1327,6 +1421,7 @@ Tensor& normal_out(
 
 Tensor randn_like(
     const Tensor& self,
+    std::optional<Generator> generator,
     std::optional<ScalarType> dtype,
     std::optional<Layout> layout,
     std::optional<Device> device,
@@ -1338,7 +1433,24 @@ Tensor randn_like(
           pin_memory);
 
   auto result = at::empty_like(self, options, optional_memory_format);
-  return result.normal_(0, 1, std::nullopt);
+  return result.normal_(0, 1, std::move(generator));
+}
+
+Tensor randn_like(
+    const Tensor& self,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  return native::randn_like(
+      self,
+      static_cast<std::optional<Generator>>(std::nullopt),
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
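The kernels above simply thread the optional generator into the existing in-place samplers (`uniform_`, `random_`, `normal_`). As a minimal Python-level sketch of that equivalence (an illustration, not code from this PR, assuming both calls consume the RNG identically):

```python
import torch

g1 = torch.Generator().manual_seed(123)
g2 = torch.Generator().manual_seed(123)
x = torch.empty(5)

# rand_like(self, generator=...) is empty_like(self) followed by uniform_(0, 1, generator)
expected = torch.empty_like(x).uniform_(0, 1, generator=g1)
actual = torch.rand_like(x, generator=g2)
assert torch.equal(expected, actual)
```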

aten/src/ATen/native/native_functions.yaml

Lines changed: 38 additions & 0 deletions
@@ -4800,6 +4800,12 @@
     CompositeExplicitAutograd: rand_like
   autogen: rand_like.out
 
+- func: rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: rand_like
+  autogen: rand_like.generator_out
+
 - func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   tags: nondeterministic_seeded
   dispatch:
@@ -4848,6 +4854,14 @@
     CompositeExplicitAutograd: randint_like
   autogen: randint_like.out
 
+- func: randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.generator_out
+
 - func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   tags: nondeterministic_seeded
   dispatch:
@@ -4856,6 +4870,14 @@
     CompositeExplicitAutograd: randint_like
   autogen: randint_like.Tensor_out
 
+- func: randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.Tensor_generator_out
+
 - func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   tags: nondeterministic_seeded
   dispatch:
@@ -4864,6 +4886,14 @@
     CompositeExplicitAutograd: randint_like
   autogen: randint_like.low_dtype_out
 
+- func: randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.low_generator_dtype_out
+
 - func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   tags: [core, nondeterministic_seeded]
   dispatch:
@@ -4904,6 +4934,14 @@
     CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
   autogen: randn_like.out
 
+- func: randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
+  autogen: randn_like.generator_out
+
 - func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   tags: [core, nondeterministic_seeded]
   dispatch:
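Assuming the newly registered schemas are exposed under `torch.ops.aten` in the usual way, here is a small sketch (not from this PR) of calling the generator overloads directly; the overload names come from the YAML entries above, and `generator` is keyword-only because it follows the `*` in each schema:

```python
import torch

t = torch.empty(3)
g = torch.Generator().manual_seed(0)

r0 = torch.ops.aten.rand_like.generator(t, generator=g)                      # rand_like.generator
r1 = torch.ops.aten.randn_like.generator(t, generator=g)                     # randn_like.generator
r2 = torch.ops.aten.randint_like.generator(t, 5, generator=g)                # (self, high, *, generator)
r3 = torch.ops.aten.randint_like.low_generator_dtype(t, 0, 5, generator=g)   # (self, low, high, *, generator)
print(r0.shape, r1.shape, r2.shape, r3.shape)
```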

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 10 additions & 0 deletions
@@ -1089,6 +1089,8 @@ aten::rand.names
 aten::rand.names_out
 aten::rand.out
 aten::rand_like
+aten::rand_like.generator
+aten::rand_like.generator_out
 aten::rand_like.out
 aten::randint
 aten::randint.generator
@@ -1100,16 +1102,24 @@ aten::randint.low_out
 aten::randint.out
 aten::randint_like
 aten::randint_like.Tensor
+aten::randint_like.Tensor_generator
+aten::randint_like.Tensor_generator_out
 aten::randint_like.Tensor_out
+aten::randint_like.generator
+aten::randint_like.generator_out
 aten::randint_like.low_dtype
 aten::randint_like.low_dtype_out
+aten::randint_like.low_generator_dtype
+aten::randint_like.low_generator_dtype_out
 aten::randint_like.out
 aten::randn.generator
 aten::randn.generator_with_names
 aten::randn.generator_with_names_out
 aten::randn.names
 aten::randn.names_out
 aten::randn_like
+aten::randn_like.generator
+aten::randn_like.generator_out
 aten::randn_like.out
 aten::random
 aten::random.from
