|
1 | 1 | from kirin.ir import types |
2 | 2 | from kirin.interp import Frame, MethodTable, StatementResult, impl |
| 3 | +from kirin.analysis.typeinfer import TypeInference |
3 | 4 |
|
4 | 5 | from . import _stmts as py |
5 | 6 | from .dialect import dialect |
@@ -175,12 +176,12 @@ def mat_mult(self, interp, frame, stmt) -> StatementResult[types.TypeAttribute]: |
175 | 176 | @impl(py.GetItem) |
176 | 177 | def getitem( |
177 | 178 | self, |
178 | | - interp, |
| 179 | + interp: TypeInference, |
179 | 180 | frame: Frame[types.TypeAttribute], |
180 | 181 | stmt: py.GetItem, |
181 | 182 | ) -> StatementResult[types.TypeAttribute]: |
182 | 183 | obj = frame.get(stmt.obj) |
183 | | - if isinstance(obj, types.Hinted): # unwrap const |
| 184 | + if interp.is_const(obj): # unwrap const |
184 | 185 | obj = obj.type |
185 | 186 | index: types.TypeAttribute = frame.get(stmt.index) |
186 | 187 | # TODO: replace this when we can multiple dispatch |
@@ -221,30 +222,31 @@ def getitem_tuple( |
221 | 222 |
|
222 | 223 | def getitem_tuple_index( |
223 | 224 | self, |
224 | | - interp, |
| 225 | + interp: TypeInference, |
225 | 226 | stmt: py.GetItem, |
226 | 227 | obj: types.Generic, |
227 | 228 | index: types.TypeAttribute, |
228 | 229 | ): |
229 | | - if isinstance(index, types.Hinted): # const |
230 | | - if obj.vararg and index.data >= len(obj.vars): |
| 230 | + if interp.is_const(index) and index.type.is_subseteq(types.Int): |
| 231 | + index_: int = index.data.data |
| 232 | + if obj.vararg and index_ >= len(obj.vars): |
231 | 233 | return (obj.vararg.typ,) |
232 | | - elif index.data < len(obj.vars): |
233 | | - return (obj.vars[index.data],) |
| 234 | + elif index_ < len(obj.vars): |
| 235 | + return (obj.vars[index_],) |
234 | 236 | else: |
235 | 237 | return (types.Bottom,) |
236 | 238 | else: |
237 | 239 | return (self.getitem_tuple_union(obj),) |
238 | 240 |
|
239 | 241 | def getitem_tuple_slice( |
240 | 242 | self, |
241 | | - interp, |
| 243 | + interp: TypeInference, |
242 | 244 | stmt: py.GetItem, |
243 | 245 | obj: types.Generic, |
244 | 246 | index: types.TypeAttribute, |
245 | 247 | ): |
246 | | - if isinstance(index, types.Hinted): |
247 | | - data: slice = index.data |
| 248 | + if interp.is_const(index): |
| 249 | + data: slice = index.data.data |
248 | 250 | if obj.vararg and data.stop >= len(obj.vars): |
249 | 251 | return ( |
250 | 252 | types.Union( |
@@ -315,17 +317,15 @@ def new_list( |
315 | 317 |
|
316 | 318 | @impl(py.Slice) |
317 | 319 | def slice( |
318 | | - self, interp, frame: Frame[types.TypeAttribute], stmt: py.Slice |
| 320 | + self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: py.Slice |
319 | 321 | ) -> StatementResult[types.TypeAttribute]: |
320 | 322 | start, stop, step = frame.get_values(stmt.args) |
321 | | - if ( |
322 | | - isinstance(start, types.Hinted) |
323 | | - and isinstance(stop, types.Hinted) |
324 | | - and isinstance(step, types.Hinted) |
325 | | - and isinstance(stmt.result.type, types.TypeAttribute) |
326 | | - ): |
| 323 | + if interp.is_const(start) and interp.is_const(stop) and interp.is_const(step): |
327 | 324 | return ( |
328 | | - types.Hinted(stmt.result.type, slice(start.data, stop.data, step.data)), |
| 325 | + types.Hinted( |
| 326 | + stmt.result.type, |
| 327 | + slice(start.data.data, stop.data.data, step.data.data), |
| 328 | + ), |
329 | 329 | ) |
330 | 330 |
|
331 | 331 | return (stmt.result.type,) |
0 commit comments