We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 61908a1 commit 9d076c9Copy full SHA for 9d076c9
src/cehrbert/models/hf_models/hf_cehrbert.py
@@ -171,6 +171,7 @@ class CehrBertPreTrainedModel(PreTrainedModel):
171
is_parallelizable = False
172
supports_gradient_checkpointing = True
173
_no_split_modules = ["BertLayer"]
174
+ _supports_sdpa = True
175
176
def _init_weights(self, module):
177
"""Initialize the weights."""
0 commit comments