Description
I have identified a potential stability issue when using TopPSampler and RandomSampler under a mixed precision policy (float16) with the TensorFlow backend.
In src/samplers/top_k_sampler.py, the inputs to random.categorical are explicitly cast to "float32" to prevent runtime errors, as noted in the source code comments:
tf does not support half precision multinomial sampling, so make sure we have full precision here.
However, this safeguard is missing in:
src/samplers/top_p_sampler.py
src/samplers/random_sampler.py
Impact
Running generation with these samplers under a mixed-precision policy may raise runtime exceptions, or cause numerical instability, on backends that do not support float16 inputs for multinomial sampling.
Proposed Fix
I have applied the same ops.cast(..., "float32") logic used in TopKSampler to both TopPSampler and RandomSampler to ensure consistency and stability across all sampling strategies.