
[Bug] Missing float32 cast in TopP and Random samplers causes crashes in Mixed Precision #2584

@AlejandroPG06

Description

I have identified a potential stability issue when using TopPSampler or RandomSampler under a mixed-precision policy (float16) with the TensorFlow backend.

In src/samplers/top_k_sampler.py, the inputs to random.categorical are explicitly cast to "float32" to prevent runtime errors, as noted in the source code comments:

tf does not support half precision multinomial sampling, so make sure we have full precision here.

However, this safeguard is missing in:

  1. src/samplers/top_p_sampler.py
  2. src/samplers/random_sampler.py

Impact

Running generation tasks with these samplers in a mixed-precision environment may raise runtime exceptions, or produce numerical instability, on backends that do not support float16 inputs to multinomial sampling operations.

Proposed Fix

I have applied the same ops.cast(..., "float32") logic used in TopKSampler to both TopPSampler and RandomSampler to ensure consistency and stability across all sampling strategies.
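To illustrate why the cast matters at the sampling step, here is a minimal, backend-agnostic sketch. It uses plain Python floats as a stand-in for ops.cast(..., "float32"): the conversion to full precision happens immediately before the categorical draw, mirroring the pattern in TopKSampler. The function name and structure are illustrative only, not the actual keras-nlp sampler code:

```python
import random

def categorical_sample(probabilities, rng):
    """Draw one index from a categorical distribution.

    The list-comprehension cast below stands in for
    ops.cast(probabilities, "float32"): inputs are promoted to full
    precision before sampling, since some backends cannot perform
    multinomial sampling on half-precision tensors.
    """
    probs = [float(p) for p in probabilities]  # the "cast to float32" step
    total = sum(probs)
    r = rng.random() * total
    cumulative = 0.0
    for index, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return index
    return len(probs) - 1

rng = random.Random(42)
print(categorical_sample([0.1, 0.7, 0.2], rng))  # → 1
```

In the real samplers the cast is a single line added just before the random.categorical call, so it changes no sampling behavior in full-precision runs and only widens the dtype when a float16 policy is active.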
