@@ -11,7 +11,6 @@
 # specific language governing permissions and limitations under the License.

 import logging
-from contextlib import contextmanager
 from typing import Callable, Optional

 import torch
@@ -110,14 +109,13 @@ def export(
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

-        with patch_mask_interface():
-            exported_program = torch.export.export(
-                self.model,
-                args=(example_input_ids, example_cache_position),
-                kwargs={},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+        exported_program = torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
         return exported_program

     @staticmethod
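With the `patch_mask_interface()` wrapper gone, the method above calls `torch.export.export` directly. A minimal, self-contained sketch of that call pattern (the toy module and tensor shapes are illustrative assumptions, not code from this file):

```python
# Minimal sketch of the torch.export.export call pattern used above; the toy
# module below stands in for self.model and is not part of this file.
import torch


class TinyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(32, 8)
        self.head = torch.nn.Linear(8, 32)

    def forward(self, input_ids, cache_position):
        hidden = self.embed(input_ids) + cache_position.float().unsqueeze(-1)
        return self.head(hidden)


example_input_ids = torch.tensor([[1]], dtype=torch.long)
example_cache_position = torch.tensor([0], dtype=torch.long)

exported_program = torch.export.export(
    TinyDecoder(),
    args=(example_input_ids, example_cache_position),
    kwargs={},
    dynamic_shapes=None,  # or a per-argument Dim spec, as elsewhere in this file
    strict=True,
)
print(exported_program)
```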
@@ -456,24 +454,6 @@ def forward(
         return outputs.logits


-@contextmanager
-def patch_mask_interface():
-    """
-    Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail
-    with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles:
-    Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`.
-    Note that this seem to be an issue only for python<3.11.
-    """
-    import transformers
-
-    original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-    transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
-    try:
-        yield
-    finally:
-        transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original
-
-
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
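For reference, the deleted helper followed the usual "swap a module-level attribute and restore it in `finally`" pattern. A generic sketch of that pattern (the `swap_attribute` name and the commented usage are illustrative, not transformers API):

```python
# Generic sketch of the monkey-patch pattern the removed helper relied on:
# temporarily replace a module-level attribute, always restore it afterwards.
# `swap_attribute` is an illustrative name, not a transformers utility.
from contextlib import contextmanager


@contextmanager
def swap_attribute(module, name, replacement):
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        setattr(module, name, original)


# The removed patch_mask_interface() did the equivalent of:
#
#     with swap_attribute(
#         transformers.masking_utils,
#         "ALL_MASK_ATTENTION_FUNCTIONS",
#         ALL_MASK_ATTENTION_FUNCTIONS._global_mapping,  # plain dict, traceable by dynamo
#     ):
#         exported = torch.export.export(...)
```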
@@ -515,14 +495,13 @@ def convert_and_export_with_cache(
         )

         if is_torch_greater_or_equal("2.6.0"):
-            with patch_mask_interface():
-                exported_program = torch.export.export(
-                    TorchExportableModuleWithStaticCache(model),
-                    args=(example_input_ids, example_cache_position),
-                    kwargs={},
-                    dynamic_shapes=dynamic_shapes,
-                    strict=strict if strict is not None else True,
-                )
+            exported_program = torch.export.export(
+                TorchExportableModuleWithStaticCache(model),
+                args=(example_input_ids, example_cache_position),
+                kwargs={},
+                dynamic_shapes=dynamic_shapes,
+                strict=strict if strict is not None else True,
+            )
         else:
             if dynamic_shapes is not None:
                 logging.warning(
@@ -534,14 +513,13 @@ def convert_and_export_with_cache(
             #
             # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
             # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
-            with patch_mask_interface():
-                exported_program = torch.export._trace._export(
-                    TorchExportableModuleWithStaticCache(model),
-                    args=(example_input_ids,),
-                    kwargs={"cache_position": example_cache_position},
-                    pre_dispatch=False,
-                    strict=True,
-                )
+            exported_program = torch.export._trace._export(
+                TorchExportableModuleWithStaticCache(model),
+                args=(example_input_ids,),
+                kwargs={"cache_position": example_cache_position},
+                pre_dispatch=False,
+                strict=True,
+            )
         return exported_program

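Either branch returns a regular `ExportedProgram`, so downstream usage is unchanged. A rough sketch of consuming the result (assumes `model` is a `PreTrainedModel` already configured for static-cache generation; the `example_cache_position` keyword is inferred from the function body above, and the token ids are placeholders):

```python
# Rough sketch of consuming the result of convert_and_export_with_cache.
# Assumes `model` is already loaded and configured for static-cache generation.
import torch

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids=torch.tensor([[1]], dtype=torch.long),
    example_cache_position=torch.tensor([0], dtype=torch.long),
)

# ExportedProgram.module() returns a callable graph module whose signature
# mirrors TorchExportableModuleWithStaticCache.forward(input_ids, cache_position).
graph_module = exported_program.module()
logits = graph_module(
    torch.tensor([[42]], dtype=torch.long),  # input_ids for a single token
    torch.tensor([0], dtype=torch.long),     # cache_position
)
print(logits.shape)
```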
@@ -634,10 +612,9 @@ def _export_encoder(self, encoder_input_ids):

         # Export the encoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_encoder = torch.export.export(
-                    wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
-                )
+            exported_encoder = torch.export.export(
+                wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
+            )

         return exported_encoder

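The encoder export keeps the sequence dimension dynamic via `dynamic_shapes={"input_ids": {1: seq_len_dim}}`. A sketch of how such a dim is typically declared (the symbol name and upper bound are illustrative assumptions):

```python
# Sketch of declaring a dynamic dim like `seq_len_dim`; the name and the
# upper bound are illustrative assumptions.
from torch.export import Dim

seq_len_dim = Dim("encoder_sequence_length", max=4096)

# Keyed by forward() argument name: dim 1 (sequence length) of `input_ids` is
# dynamic, while dim 0 (batch) stays static.
dynamic_shapes = {"input_ids": {1: seq_len_dim}}
```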
@@ -657,17 +634,16 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position

         # Export the decoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_decoder = torch.export.export(
-                    wrapped_decoder,
-                    (decoder_input_ids, encoder_hidden_states, cache_position),
-                    dynamic_shapes={
-                        "decoder_input_ids": None,
-                        "encoder_hidden_states": {1: encoder_seq_len_dim},
-                        "cache_position": None,
-                    },
-                    strict=True,
-                )
+            exported_decoder = torch.export.export(
+                wrapped_decoder,
+                (decoder_input_ids, encoder_hidden_states, cache_position),
+                dynamic_shapes={
+                    "decoder_input_ids": None,
+                    "encoder_hidden_states": {1: encoder_seq_len_dim},
+                    "cache_position": None,
+                },
+                strict=True,
+            )

         return exported_decoder

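Putting the two pieces together, a single greedy decoding step with the exported encoder/decoder pair might look like the sketch below (token ids and shapes are placeholders; the call order follows the wrappers above):

```python
# Rough sketch of one decoding step with the exported encoder/decoder pair.
# Token ids below are placeholders; .module() turns an ExportedProgram back
# into a callable graph module.
import torch

encoder_input_ids = torch.tensor([[5, 6, 7]], dtype=torch.long)
encoder_hidden_states = exported_encoder.module()(encoder_input_ids)

decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # assumed decoder start token
cache_position = torch.tensor([0], dtype=torch.long)
logits = exported_decoder.module()(decoder_input_ids, encoder_hidden_states, cache_position)
next_token = torch.argmax(logits[:, -1, :], dim=-1)
```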