@@ -53,101 +53,101 @@ class MatMulConfig(Config):
5353 def __init__ (
5454 self ,
5555 op : OpView ,
56- M_threads : int = 1 ,
57- K_threads : int = 1 ,
58- N_threads : int = 1 ,
59- M_block : int = 1 ,
60- K_block : int = 1 ,
61- N_block : int = 1 ,
62- innermostM_block : int = 1 ,
63- innermostK_block : int = 1 ,
64- innermostN_block : int = 1 ,
56+ MThreads : int = 1 ,
57+ KThreads : int = 1 ,
58+ NThreads : int = 1 ,
59+ MBlock : int = 1 ,
60+ KBlock : int = 1 ,
61+ NBlock : int = 1 ,
62+ innerMostMBlock : int = 1 ,
63+ innerMostKBlock : int = 1 ,
64+ innerMostNBlock : int = 1 ,
6565 ):
6666 # you can set the default value and candidates by info from matmul_op
67- self .M = op .inputs [0 ].type .shape [0 ]
68- self .K = op .inputs [0 ].type .shape [1 ]
69- self .N = op .inputs [1 ].type .shape [1 ]
67+ self .m = op .inputs [0 ].type .shape [0 ]
68+ self .k = op .inputs [0 ].type .shape [1 ]
69+ self .n = op .inputs [1 ].type .shape [1 ]
7070 # self.input_a_dtype = str(op.inputs[0].type.element_type)
7171 self .num_threads = int (os .environ .get ("OMP_NUM_THREADS" , 1 ))
72- self .M_threads = M_threads
73- self .K_threads = K_threads
74- self .N_threads = N_threads
75- self .M_block = M_block
76- self .K_block = K_block
77- self .N_block = N_block
78- self .innermostM_block = innermostM_block
79- self .innermostK_block = innermostK_block
80- self .innermostN_block = innermostN_block
72+ self .m_threads = MThreads
73+ self .k_threads = KThreads
74+ self .n_threads = NThreads
75+ self .m_block = MBlock
76+ self .k_block = KBlock
77+ self .n_block = NBlock
78+ self .innermost_m_block = innerMostMBlock
79+ self .innermost_k_block = innerMostKBlock
80+ self .innermost_n_block = innerMostNBlock
8181 super ().__init__ ()
8282
8383 def init_candidates (self ):
8484 default_blocks = [16 , 32 , 64 , 128 , 256 , 512 ]
8585 default_innermost_blocks = [16 , 32 ]
86- self .field_candidates ["M_threads " ] = find_factors (self .num_threads )
87- self .field_candidates ["K_threads " ] = find_factors (self .num_threads )
88- self .field_candidates ["N_threads " ] = find_factors (self .num_threads )
89- self .field_candidates ["M_block " ] = [
90- block for block in default_blocks if self .M >= block
86+ self .field_candidates ["m_threads " ] = find_factors (self .num_threads )
87+ self .field_candidates ["k_threads " ] = find_factors (self .num_threads )
88+ self .field_candidates ["n_threads " ] = find_factors (self .num_threads )
89+ self .field_candidates ["m_block " ] = [
90+ block for block in default_blocks if self .m >= block
9191 ]
92- self .field_candidates ["K_block " ] = [
93- block for block in default_blocks if self .K >= block
92+ self .field_candidates ["k_block " ] = [
93+ block for block in default_blocks if self .k >= block
9494 ]
95- self .field_candidates ["N_block " ] = [
96- block for block in default_blocks if self .N >= block
95+ self .field_candidates ["n_block " ] = [
96+ block for block in default_blocks if self .n >= block
9797 ]
98- self .field_candidates ["innermostM_block " ] = [
99- block for block in default_innermost_blocks if self .M >= block
98+ self .field_candidates ["innermost_m_block " ] = [
99+ block for block in default_innermost_blocks if self .m >= block
100100 ]
101- self .field_candidates ["innermostK_block " ] = [
102- block for block in default_innermost_blocks if self .K >= block
101+ self .field_candidates ["innermost_k_block " ] = [
102+ block for block in default_innermost_blocks if self .k >= block
103103 ]
104- self .field_candidates ["innermostN_block " ] = [
105- block for block in default_innermost_blocks if self .N >= block
104+ self .field_candidates ["innermost_n_block " ] = [
105+ block for block in default_innermost_blocks if self .n >= block
106106 ]
107107
108108 def init_constraints (self ):
109109 # example: using lambda to add constraints, adding constraints by the order of the fields
110- self .field_constraints ["M_threads " ] = None
111- self .field_constraints ["K_threads " ] = (
112- lambda MatMulConfig , K_threads : self .num_threads
113- % (MatMulConfig .M_threads * K_threads )
110+ self .field_constraints ["m_threads " ] = None
111+ self .field_constraints ["k_threads " ] = (
112+ lambda MatMulConfig , k_threads : self .num_threads
113+ % (MatMulConfig .m_threads * k_threads )
114114 == 0
115115 )
116- self .field_constraints ["N_threads " ] = (
117- lambda MatMulConfig , N_threads : self .num_threads
118- % (MatMulConfig .M_threads * MatMulConfig .K_threads * N_threads )
116+ self .field_constraints ["n_threads " ] = (
117+ lambda MatMulConfig , n_threads : self .num_threads
118+ % (MatMulConfig .m_threads * MatMulConfig .k_threads * n_threads )
119119 == 0
120120 )
121- self .field_constraints ["M_block " ] = None
122- self .field_constraints ["K_block " ] = None
123- self .field_constraints ["N_block " ] = None
124- self .field_constraints ["innermostM_block " ] = (
125- lambda MatMulConfig , innermostM_block : MatMulConfig .M_block
126- % innermostM_block
121+ self .field_constraints ["m_block " ] = None
122+ self .field_constraints ["k_block " ] = None
123+ self .field_constraints ["n_block " ] = None
124+ self .field_constraints ["innermost_m_block " ] = (
125+ lambda MatMulConfig , innermost_m_block : MatMulConfig .m_block
126+ % innermost_m_block
127127 == 0
128128 )
129- self .field_constraints ["innermostK_block " ] = (
130- lambda MatMulConfig , innermostK_block : MatMulConfig .K_block
131- % innermostK_block
129+ self .field_constraints ["innermost_k_block " ] = (
130+ lambda MatMulConfig , innermost_k_block : MatMulConfig .k_block
131+ % innermost_k_block
132132 == 0
133133 )
134- self .field_constraints ["innermostN_block " ] = (
135- lambda MatMulConfig , innermostN_block : MatMulConfig .N_block
136- % innermostN_block
134+ self .field_constraints ["innermost_n_block " ] = (
135+ lambda MatMulConfig , innermost_n_block : MatMulConfig .n_block
136+ % innermost_n_block
137137 == 0
138138 )
139139
140140 def attach_to_ir (self , op : OpView ):
141141 attr_to_field = {
142- "Mthreads " : self .M_threads ,
143- "Kthreads " : self .K_threads ,
144- "Nthreads " : self .N_threads ,
145- "MBlock" : self .M_block ,
146- "KBlock" : self .K_block ,
147- "NBlock" : self .N_block ,
148- "innermostMBlock " : self .innermostM_block ,
149- "innermostKBlock " : self .innermostK_block ,
150- "innermostNBlock " : self .innermostN_block ,
142+ "MThreads " : self .m_threads ,
143+ "KThreads " : self .k_threads ,
144+ "NThreads " : self .n_threads ,
145+ "MBlock" : self .m_block ,
146+ "KBlock" : self .k_block ,
147+ "NBlock" : self .n_block ,
148+ "innerMostMBlock " : self .innermost_m_block ,
149+ "innerMostKBlock " : self .innermost_k_block ,
150+ "innerMostNBlock " : self .innermost_n_block ,
151151 }
152152 for name , value in attr_to_field .items ():
153153 op .attributes [name ] = IntegerAttr .get (T .i32 (), value )
@@ -158,15 +158,15 @@ def __repr__(self) -> str:
158158 def __str__ (self ) -> str :
159159 obj_dict = {
160160 "MatMulConfig" : {
161- "M_threads " : self .M_threads ,
162- "K_threads " : self .K_threads ,
163- "N_threads " : self .N_threads ,
164- "M_block " : self .M_block ,
165- "K_block " : self .K_block ,
166- "N_block " : self .N_block ,
167- "innermostM_block " : self .innermostM_block ,
168- "innermostK_block " : self .innermostK_block ,
169- "innermostN_block " : self .innermostN_block ,
161+ "MThreads " : self .m_threads ,
162+ "KThreads " : self .k_threads ,
163+ "NThreads " : self .n_threads ,
164+ "MBlock " : self .m_block ,
165+ "KBlock " : self .k_block ,
166+ "NBlock " : self .n_block ,
167+ "innerMostMBlock " : self .innermost_m_block ,
168+ "innerMostKBlock " : self .innermost_k_block ,
169+ "innerMostNBlock " : self .innermost_n_block ,
170170 }
171171 }
172172 return json .dumps (obj_dict , indent = 4 )
0 commit comments