@@ -27,13 +27,21 @@ mod tensor_worker;
27
27
28
28
mod blocking;
29
29
mod panic;
30
+
31
+ use monarch_types:: py_global;
30
32
use pyo3:: prelude:: * ;
31
33
32
34
#[ pyfunction]
33
35
fn has_tensor_engine ( ) -> bool {
34
36
cfg ! ( feature = "tensor_engine" )
35
37
}
36
38
39
+ py_global ! (
40
+ add_extension_methods,
41
+ "monarch._src.actor.python_extension_methods" ,
42
+ "add_extension_methods"
43
+ ) ;
44
+
37
45
fn get_or_add_new_module < ' py > (
38
46
module : & Bound < ' py , PyModule > ,
39
47
module_name : & str ,
@@ -46,22 +54,29 @@ fn get_or_add_new_module<'py>(
46
54
if let Some ( submodule) = submodule {
47
55
current_module = submodule. extract ( ) ?;
48
56
} else {
49
- let new_module = PyModule :: new ( current_module. py ( ) , part) ?;
50
- current_module. add_submodule ( & new_module) ?;
57
+ let name = format ! ( "monarch._rust_bindings.{}" , parts. join( "." ) ) ;
58
+ let new_module = PyModule :: new ( current_module. py ( ) , & name) ?;
59
+ current_module. add ( part, new_module. clone ( ) ) ?;
51
60
current_module
52
61
. py ( )
53
62
. import ( "sys" ) ?
54
63
. getattr ( "modules" ) ?
55
- . set_item (
56
- format ! ( "monarch._rust_bindings.{}" , parts. join( "." ) ) ,
57
- new_module. clone ( ) ,
58
- ) ?;
64
+ . set_item ( name, new_module. clone ( ) ) ?;
59
65
current_module = new_module;
60
66
}
61
67
}
62
68
Ok ( current_module)
63
69
}
64
70
71
+ fn register < F > ( module : & Bound < ' _ , PyModule > , module_path : & str , register_fn : F ) -> PyResult < ( ) >
72
+ where
73
+ F : FnOnce ( & Bound < ' _ , PyModule > ) -> PyResult < ( ) > ,
74
+ {
75
+ let submodule = get_or_add_new_module ( module, module_path) ?;
76
+ register_fn ( & submodule) ?;
77
+ Ok ( ( ) )
78
+ }
79
+
65
80
#[ pymodule]
66
81
#[ pyo3( name = "_rust_bindings" ) ]
67
82
pub fn mod_init ( module : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
@@ -71,153 +86,188 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
71
86
runtime. handle ( ) . clone ( ) ,
72
87
Some ( :: hyperactor_mesh:: bootstrap:: BOOTSTRAP_INDEX_ENV . to_string ( ) ) ,
73
88
) ;
74
-
75
- monarch_hyperactor:: shape:: register_python_bindings ( & get_or_add_new_module (
89
+ register (
76
90
module,
77
91
"monarch_hyperactor.shape" ,
78
- ) ? ) ? ;
79
-
80
- monarch_hyperactor :: selection :: register_python_bindings ( & get_or_add_new_module (
92
+ monarch_hyperactor :: shape :: register_python_bindings ,
93
+ ) ? ;
94
+ register (
81
95
module,
82
96
"monarch_hyperactor.selection" ,
83
- ) ? ) ? ;
84
-
85
- monarch_hyperactor :: supervision :: register_python_bindings ( & get_or_add_new_module (
97
+ monarch_hyperactor :: selection :: register_python_bindings ,
98
+ ) ? ;
99
+ register (
86
100
module,
87
101
"monarch_hyperactor.supervision" ,
88
- ) ?) ?;
102
+ monarch_hyperactor:: supervision:: register_python_bindings,
103
+ ) ?;
89
104
90
105
#[ cfg( feature = "tensor_engine" ) ]
91
106
{
92
- client :: register_python_bindings ( & get_or_add_new_module (
107
+ register (
93
108
module,
94
109
"monarch_extension.client" ,
95
- ) ?) ?;
96
- tensor_worker:: register_python_bindings ( & get_or_add_new_module (
110
+ client:: register_python_bindings,
111
+ ) ?;
112
+ register (
97
113
module,
98
114
"monarch_extension.tensor_worker" ,
99
- ) ?) ?;
100
- controller:: register_python_bindings ( & get_or_add_new_module (
115
+ tensor_worker:: register_python_bindings,
116
+ ) ?;
117
+ register (
101
118
module,
102
119
"monarch_extension.controller" ,
103
- ) ?) ?;
104
- debugger:: register_python_bindings ( & get_or_add_new_module (
120
+ controller:: register_python_bindings,
121
+ ) ?;
122
+ register (
105
123
module,
106
124
"monarch_extension.debugger" ,
107
- ) ?) ?;
108
- monarch_messages:: debugger:: register_python_bindings ( & get_or_add_new_module (
125
+ debugger:: register_python_bindings,
126
+ ) ?;
127
+ register (
109
128
module,
110
129
"monarch_messages.debugger" ,
111
- ) ?) ?;
112
- simulator_client:: register_python_bindings ( & get_or_add_new_module (
130
+ monarch_messages:: debugger:: register_python_bindings,
131
+ ) ?;
132
+ register (
113
133
module,
114
134
"monarch_extension.simulator_client" ,
115
- ) ?) ?;
116
- :: controller:: bootstrap:: register_python_bindings ( & get_or_add_new_module (
135
+ simulator_client:: register_python_bindings,
136
+ ) ?;
137
+ register (
117
138
module,
118
139
"controller.bootstrap" ,
119
- ) ?) ?;
120
- :: monarch_tensor_worker:: bootstrap:: register_python_bindings ( & get_or_add_new_module (
140
+ :: controller:: bootstrap:: register_python_bindings,
141
+ ) ?;
142
+ register (
121
143
module,
122
144
"monarch_tensor_worker.bootstrap" ,
123
- ) ?) ?;
124
- crate :: convert:: register_python_bindings ( & get_or_add_new_module (
145
+ :: monarch_tensor_worker:: bootstrap:: register_python_bindings,
146
+ ) ?;
147
+ register (
125
148
module,
126
149
"monarch_extension.convert" ,
127
- ) ?) ?;
128
- crate :: mesh_controller:: register_python_bindings ( & get_or_add_new_module (
150
+ crate :: convert:: register_python_bindings,
151
+ ) ?;
152
+ register (
129
153
module,
130
154
"monarch_extension.mesh_controller" ,
131
- ) ?) ?;
132
- monarch_rdma_extension:: register_python_bindings ( & get_or_add_new_module ( module, "rdma" ) ?) ?;
155
+ crate :: mesh_controller:: register_python_bindings,
156
+ ) ?;
157
+ register (
158
+ module,
159
+ "rdma" ,
160
+ monarch_rdma_extension:: register_python_bindings,
161
+ ) ?;
133
162
}
134
- simulation_tools :: register_python_bindings ( & get_or_add_new_module (
163
+ register (
135
164
module,
136
165
"monarch_extension.simulation_tools" ,
137
- ) ?) ?;
138
- monarch_hyperactor:: bootstrap:: register_python_bindings ( & get_or_add_new_module (
166
+ simulation_tools:: register_python_bindings,
167
+ ) ?;
168
+ register (
139
169
module,
140
170
"monarch_hyperactor.bootstrap" ,
141
- ) ?) ?;
171
+ monarch_hyperactor:: bootstrap:: register_python_bindings,
172
+ ) ?;
142
173
143
- monarch_hyperactor :: proc :: register_python_bindings ( & get_or_add_new_module (
174
+ register (
144
175
module,
145
176
"monarch_hyperactor.proc" ,
146
- ) ?) ?;
177
+ monarch_hyperactor:: proc:: register_python_bindings,
178
+ ) ?;
147
179
148
- monarch_hyperactor :: actor :: register_python_bindings ( & get_or_add_new_module (
180
+ register (
149
181
module,
150
182
"monarch_hyperactor.actor" ,
151
- ) ?) ?;
183
+ monarch_hyperactor:: actor:: register_python_bindings,
184
+ ) ?;
152
185
153
- monarch_hyperactor :: pytokio :: register_python_bindings ( & get_or_add_new_module (
186
+ register (
154
187
module,
155
188
"monarch_hyperactor.pytokio" ,
156
- ) ? ) ? ;
157
-
158
- monarch_hyperactor :: mailbox :: register_python_bindings ( & get_or_add_new_module (
189
+ monarch_hyperactor :: pytokio :: register_python_bindings ,
190
+ ) ? ;
191
+ register (
159
192
module,
160
193
"monarch_hyperactor.mailbox" ,
161
- ) ?) ?;
194
+ monarch_hyperactor:: mailbox:: register_python_bindings,
195
+ ) ?;
162
196
163
- monarch_hyperactor :: alloc :: register_python_bindings ( & get_or_add_new_module (
197
+ register (
164
198
module,
165
199
"monarch_hyperactor.alloc" ,
166
- ) ?) ?;
167
- monarch_hyperactor:: channel:: register_python_bindings ( & get_or_add_new_module (
200
+ monarch_hyperactor:: alloc:: register_python_bindings,
201
+ ) ?;
202
+ register (
168
203
module,
169
204
"monarch_hyperactor.channel" ,
170
- ) ?) ?;
171
- monarch_hyperactor:: actor_mesh:: register_python_bindings ( & get_or_add_new_module (
205
+ monarch_hyperactor:: channel:: register_python_bindings,
206
+ ) ?;
207
+ register (
172
208
module,
173
209
"monarch_hyperactor.actor_mesh" ,
174
- ) ?) ?;
175
- monarch_hyperactor:: proc_mesh:: register_python_bindings ( & get_or_add_new_module (
210
+ monarch_hyperactor:: actor_mesh:: register_python_bindings,
211
+ ) ?;
212
+ register (
176
213
module,
177
214
"monarch_hyperactor.proc_mesh" ,
178
- ) ?) ?;
215
+ monarch_hyperactor:: proc_mesh:: register_python_bindings,
216
+ ) ?;
179
217
180
- monarch_hyperactor :: runtime :: register_python_bindings ( & get_or_add_new_module (
218
+ register (
181
219
module,
182
220
"monarch_hyperactor.runtime" ,
183
- ) ?) ?;
184
- monarch_hyperactor:: telemetry:: register_python_bindings ( & get_or_add_new_module (
221
+ monarch_hyperactor:: runtime:: register_python_bindings,
222
+ ) ?;
223
+ register (
185
224
module,
186
225
"monarch_hyperactor.telemetry" ,
187
- ) ?) ?;
188
- code_sync:: register_python_bindings ( & get_or_add_new_module (
226
+ monarch_hyperactor:: telemetry:: register_python_bindings,
227
+ ) ?;
228
+ register (
189
229
module,
190
230
"monarch_extension.code_sync" ,
191
- ) ?) ?;
231
+ code_sync:: register_python_bindings,
232
+ ) ?;
192
233
193
- crate :: panic :: register_python_bindings ( & get_or_add_new_module (
234
+ register (
194
235
module,
195
236
"monarch_extension.panic" ,
196
- ) ?) ?;
237
+ crate :: panic:: register_python_bindings,
238
+ ) ?;
197
239
198
- crate :: blocking :: register_python_bindings ( & get_or_add_new_module (
240
+ register (
199
241
module,
200
242
"monarch_extension.blocking" ,
201
- ) ?) ?;
243
+ crate :: blocking:: register_python_bindings,
244
+ ) ?;
202
245
203
- crate :: logging :: register_python_bindings ( & get_or_add_new_module (
246
+ register (
204
247
module,
205
248
"monarch_extension.logging" ,
206
- ) ?) ?;
249
+ crate :: logging:: register_python_bindings,
250
+ ) ?;
207
251
208
252
#[ cfg( fbcode_build) ]
209
253
{
210
- monarch_hyperactor :: meta :: alloc :: register_python_bindings ( & get_or_add_new_module (
254
+ register (
211
255
module,
212
256
"monarch_hyperactor.meta.alloc" ,
213
- ) ?) ?;
214
- monarch_hyperactor:: meta:: alloc_mock:: register_python_bindings ( & get_or_add_new_module (
257
+ monarch_hyperactor:: meta:: alloc:: register_python_bindings,
258
+ ) ?;
259
+ register (
215
260
module,
216
261
"monarch_hyperactor.meta.alloc_mock" ,
217
- ) ?) ?;
262
+ monarch_hyperactor:: meta:: alloc_mock:: register_python_bindings,
263
+ ) ?;
218
264
}
219
265
// Add feature detection function
220
266
module. add_function ( wrap_pyfunction ! ( has_tensor_engine, module) ?) ?;
221
267
268
+ // this should be called last. otherwise cross references in pyi files will not have been
269
+ // added to sys.modules yet.
270
+ add_extension_methods ( module. py ( ) ) . call1 ( ( module, ) ) ?;
271
+
222
272
Ok ( ( ) )
223
273
}
0 commit comments