
Commit 7d4f297

fengyuan14 and Zhu, Yuhua authored
Bypass XeTLA implementation on platforms without XMX and add error handling for LLM (#3841)

* update has_xmx only for xpu device
* qkv_gemm: bypass XeTLA implementation on platforms without XMX equipped
* add torch.xpu.has_xmx() api and error handling for LLM
* Fix flake8

---------

Signed-off-by: Zhu, Yuhua <[email protected]>
Signed-off-by: Feng Yuan <[email protected]>
Co-authored-by: Zhu, Yuhua <[email protected]>
1 parent b06fc38 commit 7d4f297
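
As a quick orientation (not part of the commit itself), here is a minimal sketch of how the new capability check might be used from Python. It assumes intel_extension_for_pytorch is installed and that, as the commit message states, it registers torch.xpu.has_xmx(); the messages printed are illustrative only.

# Minimal sketch, assuming intel_extension_for_pytorch registers
# torch.xpu.has_xmx() as described in the commit message above.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401

if torch.xpu.is_available() and torch.xpu.has_xmx():
    # XMX present: XeTLA-backed kernels such as the fused QKV GEMM are eligible.
    print("XMX is supported; XeTLA paths can be used")
else:
    # No XMX (e.g. PVC1550vg, per the comments added in optimize.py below):
    # the XeTLA implementation is bypassed and optimize_transformers falls back.
    print("XMX is not supported; XeTLA paths are bypassed")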

File tree

6 files changed: +67 -12 lines changed

csrc/gpu/aten/operators/XeGemm.cpp

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 #include <ATen/ATen.h>
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/record_function.h>
+#include <runtime/Device.h>
 #include <runtime/Utils.h>
 #include <iostream>
 #include "Linear.h"
@@ -465,7 +466,7 @@ static void mm_qkv_out(
       out1_valid && out2_valid && input_valid && weight_valid && bias_valid &&
       shape_valid;

-  if (use_xetla) {
+  if (dpcppGetDeviceHasXMX() && use_xetla) {
     char str__[100];
     if (!has_bias) {
       sprintf(str__, "hgemm_qkv(%d, %d, %d)", m, n, k);

csrc/gpu/utils/Settings.cpp

Lines changed: 5 additions & 0 deletions
@@ -175,6 +175,11 @@ bool Settings::has_atomic64(int device_id) {
   return dpcppGetDeviceProperties(device_id)->support_atomic64;
 }

+bool Settings::has_xmx(int device_id) {
+  // whether XMX is supported in the specified platform.
+  return dpcppGetDeviceHasXMX(device_id);
+}
+
 int Settings::get_verbose_level() const {
   std::lock_guard<std::mutex> lock(s_mutex);
   return static_cast<int>(verbose_level);

csrc/gpu/utils/Settings.h

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class IPEX_API Settings final {
   bool has_fp64_dtype(int device_id = -1);
   bool has_2d_block_array(int device_id = -1);
   bool has_atomic64(int device_id = -1);
+  bool has_xmx(int device_id = -1);

   static Settings& I(); // Singleton

intel_extension_for_pytorch/csrc/xpu/Module.cpp

Lines changed: 2 additions & 0 deletions
@@ -663,6 +663,8 @@ void init_xpu_module(pybind11::module& m) {
     return Settings::I().has_2d_block_array(device);
   });

+  m.def("_has_xmx", [](int device) { return Settings::I().has_xmx(device); });
+
   m.def(
       "_get_verbose_level", []() { return Settings::I().get_verbose_level(); });
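
The binding above exposes Settings::I().has_xmx(device) to Python as ipex._C._has_xmx(device). A small sketch of how it can be combined with the existing _has_2d_block_array binding to classify the platform, mirroring the checks added to optimize.py below; the tier labels are illustrative, only the two private bindings and the device examples come from this commit.

# Sketch of the platform classification used by optimize.py below.
# Tier labels are illustrative; the bindings and device examples are from the diff.
import intel_extension_for_pytorch as ipex

has_xmx = ipex._C._has_xmx(0)                  # new binding from this commit
has_2d_load = ipex._C._has_2d_block_array(0)   # existing binding

if has_xmx and has_2d_load:
    tier = "XMX + 2D load (e.g. PVC1100, PVC1100c, PVC1550)"
elif has_xmx:
    tier = "XMX without 2D load (e.g. ATS-M, ARC)"
else:
    tier = "no XMX (e.g. PVC1550vg): ipex.optimize_transformers is not supported"
print(tier)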

intel_extension_for_pytorch/transformers/optimize.py

Lines changed: 52 additions & 11 deletions
@@ -586,17 +586,56 @@ def optimize_transformers(
         or re.search("falcon", model.config.architectures[0], re.IGNORECASE)
         or re.search("rw", model.config.architectures[0], re.IGNORECASE)
     ) and device == "cpu"
-    # bypass_ref_model = (re.search("Bloom", model.config.architectures[0], re.IGNORECASE)) or device == "xpu"
-    xpu_supported_model = (
-        re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
-        or re.search("llama", model.config.architectures[0], re.IGNORECASE)
-        or re.search("OPT", model.config.architectures[0], re.IGNORECASE)
-        or re.search("Bloom", model.config.architectures[0], re.IGNORECASE)
-    ) and device == "xpu"
-    if not (well_supported_model or xpu_supported_model):
+
+    # If the XPU platform does not have XMX, such as PVC1550vg, ipex.optimize_transformers is not supported.
+    # If the XPU platform has XMX and 2D load instructions, such as PVC1100, PVC1100c, and PVC1550,
+    # ipex.optimize_transformers supports GPT-J, Llama, OPT, Bloom, Falcon, QWen.
+    xpu_2d_load_supported_model = (
+        (
+            re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
+            or re.search("llama", model.config.architectures[0], re.IGNORECASE)
+            or re.search("OPT", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Bloom", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Falcon", model.config.architectures[0], re.IGNORECASE)
+            or re.search("QWen", model.config.architectures[0], re.IGNORECASE)
+            or re.search("Baichuan", model.config.architectures[0], re.IGNORECASE)
+        )
+        and device == "xpu"
+        and ipex._C._has_2d_block_array(0)
+        and ipex._C._has_xmx(0)
+    )
+
+    # If the XPU platform has XMX but no 2D load instructions, such as ATS-M and ARC,
+    # ipex.optimize_transformers supports GPT-J, Llama, QWen.
+    xpu_non_2d_load_supported_model = (
+        (
+            re.search("GPTJ", model.config.architectures[0], re.IGNORECASE)
+            or re.search("llama", model.config.architectures[0], re.IGNORECASE)
+            or re.search("QWen", model.config.architectures[0], re.IGNORECASE)
+        )
+        and device == "xpu"
+        and not ipex._C._has_2d_block_array(0)
+        and ipex._C._has_xmx(0)
+    )
+
+    if not (
+        well_supported_model
+        or xpu_2d_load_supported_model
+        or xpu_non_2d_load_supported_model
+    ):
         warnings.warn(
-            "optimize_transformers supports GPT-J/Llama/OPT/Bloom in XPU and Llama/GPT-J/GPT-Neox/Falcon/OPT"
-            " in CPU, fallback to origin model"
+            "The compatibility of ipex.optimize_transformers depends on the CPU/XPU platform "
+            " and the transformer model. Here are the general rules: "
+            " If the XPU platform does not have XMX, such as PVC1550vg, "
+            " ipex.optimize_transformers is not supported. "
+            " If the XPU platform has XMX and 2D load instructions, such as PVC1100, PVC1100c, and PVC1550,"
+            " ipex.optimize_transformers supports GPT-J/Llama/OPT/Bloom/Falcon/QWen, "
+            " and BasicTransformerBlock of diffusers. "
+            " If the XPU platform has XMX but no 2D load instructions, such as ATS-M and ARC, "
+            " ipex.optimize_transformers supports GPT-J/Llama/QWen, "
+            " and BasicTransformerBlock of diffusers. "
+            " If the platform is CPU, "
+            " ipex.optimize_transformers supports Llama, GPT-J, GPT-Neox, Falcon, and OPT."
         )
         return model

@@ -655,7 +694,9 @@ def optimize_transformers(
         xpu_woq = True

     # model reference conversion
-    if not (xpu_supported_model or xpu_woq):
+    if not (
+        xpu_2d_load_supported_model or xpu_non_2d_load_supported_model or xpu_woq
+    ):
         _model = model_convert_reference(_model)

     # model quantization if needed

intel_extension_for_pytorch/xpu/utils.py

Lines changed: 5 additions & 0 deletions
@@ -88,6 +88,11 @@ def has_2d_block_array(device: int = -1) -> bool:
     return _C._has_2d_block_array(device)


+def has_xmx(device: int = -1) -> bool:
+    r"""Returns a bool indicating if the platform supports xmx"""
+    return _C._has_xmx(device)
+
+
 # Basic OnOff
 class OnOff:
     def __init__(self, checker, enable, disable):
