@@ -586,17 +586,56 @@ def optimize_transformers(
         or re.search("falcon", model.config.architectures[0], re.IGNORECASE)
         or re.search("rw", model.config.architectures[0], re.IGNORECASE)
     ) and device == "cpu"
-    # bypass_ref_model = (re.search("Bloom", model.config.architectures[0], re.IGNORECASE)) or device == "xpu"
-    xpu_supported_model = (
-        re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
-        or re.search("llama", model.config.architectures[0], re.IGNORECASE)
-        or re.search("OPT", model.config.architectures[0], re.IGNORECASE)
-        or re.search("Bloom", model.config.architectures[0], re.IGNORECASE)
-    ) and device == "xpu"
-    if not (well_supported_model or xpu_supported_model):
+
+    # If the XPU platform does not have XMX, such as PVC1550vg, ipex.optimize_transformers is not supported.
+    # If the XPU platform has XMX and 2D load instructions, such as PVC1100, PVC1100c, and PVC1550,
+    # ipex.optimize_transformers supports GPT-J, Llama, OPT, Bloom, Falcon, QWen, and Baichuan.
+    xpu_2d_load_supported_model = (
+        (
+            re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
+            or re.search("llama", model.config.architectures[0], re.IGNORECASE)
+            or re.search("OPT", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Bloom", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Falcon", model.config.architectures[0], re.IGNORECASE)
+            or re.search("QWen", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Baichuan", model.config.architectures[0], re.IGNORECASE)
+        )
+        and device == "xpu"
+        and ipex._C._has_2d_block_array(0)
+        and ipex._C._has_xmx(0)
+    )
+
+    # If the XPU platform has XMX but no 2D load instructions, such as ATS-M and ARC,
+    # ipex.optimize_transformers supports GPT-J, Llama, QWen.
+    xpu_non_2d_load_supported_model = (
+        (
+            re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
+            or re.search("llama", model.config.architectures[0], re.IGNORECASE)
+            or re.search("QWen", model.config.architectures[0], re.IGNORECASE)
+        )
+        and device == "xpu"
+        and not ipex._C._has_2d_block_array(0)
+        and ipex._C._has_xmx(0)
+    )
+
+    if not (
+        well_supported_model
+        or xpu_2d_load_supported_model
+        or xpu_non_2d_load_supported_model
+    ):
         warnings.warn(
-            "optimize_transformers supports GPT-J/Llama/OPT/Bloom in XPU and Llama/GPT-J/GPT-Neox/Falcon/OPT"
-            " in CPU, fallback to origin model"
+            "The compatibility of ipex.optimize_transformers depends on the CPU/XPU platform"
+            " and the transformer model. Here are the general rules:"
+            " If the XPU platform does not have XMX, such as PVC1550vg,"
+            " ipex.optimize_transformers is not supported."
+            " If the XPU platform has XMX and 2D load instructions, such as PVC1100, PVC1100c, and PVC1550,"
+            " ipex.optimize_transformers supports GPT-J/Llama/OPT/Bloom/Falcon/QWen/Baichuan,"
+            " and BasicTransformerBlock of diffusers."
+            " If the XPU platform has XMX but no 2D load instructions, such as ATS-M and ARC,"
+            " ipex.optimize_transformers supports GPT-J/Llama/QWen,"
+            " and BasicTransformerBlock of diffusers."
+            " If the platform is CPU,"
+            " ipex.optimize_transformers supports Llama, GPT-J, GPT-Neox, Falcon, and OPT."
         )
         return model
 
@@ -655,7 +694,9 @@ def optimize_transformers(
         xpu_woq = True
 
     # model reference conversion
-    if not (xpu_supported_model or xpu_woq):
+    if not (
+        xpu_2d_load_supported_model or xpu_non_2d_load_supported_model or xpu_woq
+    ):
         _model = model_convert_reference(_model)
 
     # model quantization if needed
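
For orientation only, here is a minimal caller-side sketch of how the platform probes used in this patch might gate a call to ipex.optimize_transformers. It is not part of the commit: the ipex._C._has_xmx(0) / ipex._C._has_2d_block_array(0) probes, the device="xpu" path, and the fall-back-with-warning behavior come from the diff above, while the transformers model loading, the placeholder checkpoint path, and the dtype keyword argument are assumptions for illustration.

# Hypothetical usage sketch (not part of this commit).
# Assumptions: AutoModelForCausalLM loading, the placeholder checkpoint path,
# and the dtype keyword argument passed to optimize_transformers.
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
model = model.eval().to("xpu")

# Same XMX probe the patch uses; optimize_transformers additionally checks
# ipex._C._has_2d_block_array(0) internally to pick the supported-model list
# (GPT-J/Llama/OPT/Bloom/Falcon/QWen/Baichuan with 2D load, GPT-J/Llama/QWen without).
if ipex._C._has_xmx(0):
    # Unsupported architectures fall back to the original model with a warning.
    model = ipex.optimize_transformers(model, dtype=torch.float16, device="xpu")
else:
    # No XMX (e.g. PVC1550vg): optimize_transformers is not supported; keep the eager model.
    pass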