38
38
39
39
opaque = lambda dialect_namespace , buffer : OpaqueType .get (dialect_namespace , buffer )
40
40
41
- range_ = for_
42
41
42
+ def canonicalize_start_stop_step (start , stop , step ):
43
+ if step is None :
44
+ step = 1
45
+ if stop is None :
46
+ stop = start
47
+ start = 0
48
+ params = [start , stop , step ]
49
+ type = IndexType .get ()
50
+ maybe_types = {p .type for p in params if isinstance (p , Value )}
51
+ if maybe_types :
52
+ if len (maybe_types ) > 1 :
53
+ raise ValueError (
54
+ f"all { start = } and { stop = } and { step = } ir.Value objects must have the same type"
55
+ )
56
+ type = maybe_types .pop ()
43
57
44
- def placeholder_opaque_t ():
45
- return opaque ("scf" , "placeholder" )
58
+ for i , p in enumerate (params ):
59
+ if isinstance (p , int ):
60
+ p = _ext_arith_constant (p , type = type )
61
+ assert isinstance (p , Value )
62
+ params [i ] = p
63
+
64
+ return params [0 ], params [1 ], params [2 ]
46
65
47
66
48
- def _for (
67
+ def _build_for (
49
68
start ,
50
69
stop = None ,
51
70
step = None ,
@@ -54,25 +73,39 @@ def _for(
54
73
loc = None ,
55
74
ip = None ,
56
75
):
57
- if step is None :
58
- step = 1
59
- if stop is None :
60
- stop = start
61
- start = 0
62
- params = [start , stop , step ]
63
- for i , p in enumerate (params ):
64
- if isinstance (p , int ):
65
- p = _ext_arith_constant (p , index = True )
66
- if not _is_index_type (p .type ):
67
- p = index_cast (p )
68
- params [i ] = p
76
+ start , stop , step = canonicalize_start_stop_step (start , stop , step )
77
+ return ForOp (start , stop , step , iter_args , loc = loc , ip = ip )
69
78
79
+
80
+ def range_ (
81
+ start ,
82
+ stop = None ,
83
+ step = None ,
84
+ iter_args : Optional [Sequence [Value ]] = None ,
85
+ * ,
86
+ loc = None ,
87
+ ip = None ,
88
+ ):
70
89
if loc is None :
71
90
loc = get_user_code_loc ()
72
- return ForOp (* params , iter_args , loc = loc , ip = ip )
91
+
92
+ for_op = _build_for (start , stop , step , iter_args , loc = loc , ip = ip )
93
+ iv = for_op .induction_variable
94
+ iter_args = tuple (for_op .inner_iter_args )
95
+ with InsertionPoint (for_op .body ):
96
+ if len (iter_args ) > 1 :
97
+ yield iv , iter_args , for_op .results
98
+ elif len (iter_args ) == 1 :
99
+ yield iv , iter_args [0 ], for_op .results [0 ]
100
+ else :
101
+ yield iv
102
+
103
+
104
+ def placeholder_opaque_t ():
105
+ return opaque ("scf" , "placeholder" )
73
106
74
107
75
- for_ = region_op (_for , terminator = yield__ )
108
+ for_ = region_op (_build_for , terminator = yield__ )
76
109
77
110
78
111
@_cext .register_operation (_Dialect , replace = True )
0 commit comments