1313_CUTLASS_GROUPGEMM_CONFIG_MAP : Dict [Tuple [int , int , int , str ], Dict ] = {}
1414
1515
16- def _load_all_configs ():
16+ def load_all_configs ():
1717 """Load all cutlass groupgemm config files into the global map."""
1818 if _CUTLASS_GROUPGEMM_CONFIG_MAP :
1919 # Already loaded
@@ -23,38 +23,58 @@ def _load_all_configs():
2323
2424 # Load open source config directory
2525 opensource_dir = os .path .join (os .path .dirname (os .path .realpath (__file__ )), op_name )
26- if os .path .exists (opensource_dir ):
27- pattern = os .path .join (opensource_dir , "E=*-N=*-K=*-device_name=*.json" )
28- for config_file in glob .glob (pattern ):
29- filename = os .path .basename (config_file )
30- try :
31- # Parse filename: E={E}-N={N}-K={K}-device_name={device_name}.json
32- parts = filename .replace (".json" , "" ).split ("-" )
33- E = int (parts [0 ].split ("=" )[1 ])
34- N = int (parts [1 ].split ("=" )[1 ])
35- K = int (parts [2 ].split ("=" )[1 ])
36- device_name = parts [3 ].split ("=" )[1 ]
37-
38- # Load config
39- with open (config_file ) as f :
40- config_data = json .load (f )
41- # Convert string keys to int
42- config_data = {int (key ): val for key , val in config_data .items ()}
43-
44- # Store in global map
45- key = (E , N , K , device_name )
46- _CUTLASS_GROUPGEMM_CONFIG_MAP [key ] = config_data
47- logging .debug (f"Loaded config from { config_file } " )
48- except Exception as e :
49- logging .warning (f"Failed to load config from { config_file } : { e } " )
26+
27+ # Try to get internal source config directory
28+ # Collect all config directories to load
29+ config_dirs = [opensource_dir ]
30+ try :
31+ import internal_source .rtp_llm .models_py .kernels .cuda .fp8_kernel
32+
33+ internalsource_dir = os .path .join (
34+ os .path .dirname (
35+ os .path .realpath (
36+ internal_source .rtp_llm .models_py .kernels .cuda .fp8_kernel .__file__
37+ )
38+ ),
39+ op_name ,
40+ )
41+ config_dirs .append (internalsource_dir )
42+ except ImportError :
43+ logging .info ("internal_source not found" )
44+
45+ # Load configs from all directories
46+ for config_dir in config_dirs :
47+ if os .path .exists (config_dir ):
48+ logging .info (f"Loading configs from { config_dir } " )
49+ pattern = os .path .join (config_dir , "E=*-N=*-K=*-device_name=*.json" )
50+ for config_file in glob .glob (pattern ):
51+ filename = os .path .basename (config_file )
52+ try :
53+ # Parse filename: E={E}-N={N}-K={K}-device_name={device_name}.json
54+ parts = filename .replace (".json" , "" ).split ("-" )
55+ E = int (parts [0 ].split ("=" )[1 ])
56+ N = int (parts [1 ].split ("=" )[1 ])
57+ K = int (parts [2 ].split ("=" )[1 ])
58+ device_name = parts [3 ].split ("=" )[1 ]
59+
60+ # Load config
61+ with open (config_file ) as f :
62+ config_data = json .load (f )
63+ # Convert string keys to int
64+ config_data = {
65+ int (key ): val for key , val in config_data .items ()
66+ }
67+
68+ # Store in global map
69+ key = (E , N , K , device_name )
70+ _CUTLASS_GROUPGEMM_CONFIG_MAP [key ] = config_data
71+ logging .debug (f"Loaded config from { config_file } " )
72+ except Exception as e :
73+ logging .warning (f"Failed to load config from { config_file } : { e } " )
5074
5175 logging .info (
5276 f"Loaded { len (_CUTLASS_GROUPGEMM_CONFIG_MAP )} cutlass groupgemm configurations"
5377 )
54- try :
55- import internal_source .rtp_llm .utils .register_cutlass_configs
56- except :
57- logging .info ("internal_source not found" )
5878
5979
6080def register_cutlass_groupgemm_config (
@@ -90,9 +110,6 @@ def get_cutlass_groupgemm_best_config(E: int, N: int, K: int) -> Optional[Dict]:
90110 Configuration dictionary mapping batch sizes to tile configurations,
91111 or None if no configuration is found.
92112 """
93- # Load all configs if not already loaded
94- _load_all_configs ()
95-
96113 device_name = torch .cuda .get_device_name ().replace ("-" , "_" ).replace (" " , "_" )
97114 key = (E , N , K , device_name )
98115
0 commit comments