Skip to content

Commit c3a0eaa

Browse files
authored
[cherry-pick]Support fixed seed in Python for test (#36065) (#36094)
When users use gumbel_softmax, they can use paddle.seed() in python for fixed seed.
1 parent bc13ab9 commit c3a0eaa

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

paddle/fluid/operators/gumbel_softmax_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
130130
T* random_data =
131131
random_tensor.mutable_data<T>({size}, platform::CUDAPlace());
132132
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
133-
const unsigned int seed = std::random_device()();
134133

135134
// generate gumbel noise
136135
int device_id =
@@ -144,6 +143,7 @@ struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
144143
thrust::device_ptr<T>(random_data),
145144
UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset));
146145
} else {
146+
const unsigned int seed = std::random_device()();
147147
thrust::transform(index_sequence_begin, index_sequence_begin + size,
148148
thrust::device_ptr<T>(random_data),
149149
UniformCUDAGenerator<T>(0.00001, 1, seed));

paddle/fluid/operators/gumbel_softmax_op.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ struct GumbleNoiseGenerator<platform::CPUDeviceContext, T> {
8686
// generate uniform random number
8787
const int size = size_to_axis * size_from_axis;
8888
std::uniform_real_distribution<T> dist(0.00001, 1);
89-
const int seed = std::random_device()();
90-
auto engine = paddle::framework::GetCPURandomEngine(seed);
89+
auto engine = paddle::framework::GetCPURandomEngine(0);
9190
Tensor random_tensor;
9291
auto* random_data =
9392
random_tensor.mutable_data<T>({size}, platform::CPUPlace());

0 commit comments

Comments
 (0)