Skip to content

Commit e9c57b9

Browse files
Use eager mask all the time for causal lm (#1424)
1 parent ecc4bb1 commit e9c57b9

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

optimum/exporters/openvino/model_patcher.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,10 @@ def __enter__(self):
356356
# Although I'm not sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
357357
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
358358

359-
# for non-stateful decoder models, we use eager mask without vmap for sdpa as well
360-
# to avoid a nan output issue in OpenVINO that only happens in case of non-stateful models
361-
if not getattr(self.real_config, "stateful", False):
362-
logger.warning(
363-
"Exporting a non-stateful decoder model currently results in a nan output in OpenVINO. "
364-
"There might be a performance impact due to the use of eager mask (floats) instead of sdpa mask (bools). "
365-
)
366-
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
359+
# For decoder models, we also register the eager mask (without vmap) for sdpa
360+
# to avoid a NaN output issue in OpenVINO that has been observed with
361+
# non-stateful models on CPU and stateful models on NPU.
362+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
367363

368364
def __exit__(self, exc_type, exc_value, traceback):
369365
super().__exit__(exc_type, exc_value, traceback)
@@ -4771,14 +4767,10 @@ def __enter__(self):
47714767
# Although I'm not sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
47724768
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
47734769

4774-
# for non-stateful decoder models, we use eager mask without vmap for sdpa as well
4775-
# to avoid a nan output issue in OpenVINO that only happens in case of non-stateful models
4776-
if not getattr(self.real_config, "stateful", False):
4777-
logger.warning(
4778-
"Exporting a non-stateful decoder model currently results in a nan output in OpenVINO. "
4779-
"There might be a performance impact due to the use of eager mask (floats) instead of sdpa mask (bools). "
4780-
)
4781-
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
4770+
# For decoder models, we also register the eager mask (without vmap) for sdpa
4771+
# to avoid a NaN output issue in OpenVINO that has been observed with
4772+
# non-stateful models on CPU and stateful models on NPU.
4773+
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", eager_mask_without_vmap)
47824774

47834775
def __exit__(self, exc_type, exc_value, traceback):
47844776
super().__exit__(exc_type, exc_value, traceback)

0 commit comments

Comments
 (0)