from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo
from azure.ai.ml.entities._component.component import Component
from azure.ai.ml.entities._validation import MutableValidationResult
- from .subcomponents import aggregate_output
+ from azure.ai.ml.dsl._do_while import do_while
+ from azure.ai.ml.dsl import pipeline
+ from .subcomponents import create_scatter_output_table
+
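+ # NOTE: the dsl imports above replace the Python-level iteration loop further down:
+ # `pipeline` builds the per-iteration subgraph and `do_while` repeats it inside the pipeline itself.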

# TODO 2293610: add support for more types of outputs besides uri_folder and mltable
# Likely types that ought to be mergeable: string, int, uri_file
MERGE_COMPONENT_MAPPING = {
-     "mltable": aggregate_output,
-     "uri_folder": aggregate_output,
+     "mltable": create_scatter_output_table,
+     "uri_folder": create_scatter_output_table,
}


# big TODO: For some reason, surfacing this file in __init__.py causes
# a circular import exception on the first attempted import
# In notebooks, the second import succeeds, but then causes a silent failure where the
- # MLDesigner component created by the subcomponents.aggregate_output function
+ # MLDesigner component created by the subcomponents.create_scatter_output_table function
# will produce a ComponentExecutor object instead of the actual component.
# TODO 2293541: Add telemetry of some sort
# pylint: disable=too-many-instance-attributes
@@ -101,37 +104,20 @@ def __init__(
        self.silo_to_aggregation_argument_map = silo_to_aggregation_argument_map
        self.aggregation_to_silo_argument_map = aggregation_to_silo_argument_map
        self.max_iterations = max_iterations
-         # subgraph is a list of all iteration's executed steps, stored in a dictionary for each iteration
-         self.subgraph = []
        self._init = True  # Needed by parent class to work properly

-         executed_aggregation_step = None
-         # TODO 2293573: replace this for-loop with a do-while node.
-         for _ in range(self.max_iterations):
-             # Create inputs for silo components
-             silo_inputs = {}
-             # Start with static inputs
-             silo_inputs.update(self.shared_silo_kwargs)
-             # merge in inputs passed in from previous iteration's aggregate step if they exist
-             if executed_aggregation_step is not None and self.aggregation_to_silo_argument_map is not None:
-                 silo_inputs.update(
-                     FLScatterGather._extract_outputs(
-                         executed_aggregation_step.outputs, self.aggregation_to_silo_argument_map
-                     )
-                 )
-             # per-silo inputs added in during scatter_gather
-             # Run scatter-gather iteration.
-             self.subgraph.append(self.scatter_gather(silo_inputs))
-             # re-assign previous agg step to extract outputs for next iter's inputs
-             executed_aggregation_step = self.subgraph[-1]["aggregation"]
+         self.scatter_gather_graph = self.scatter_gather()
+
+         # set SG node flag for telemetry
+         self.scatter_gather_graph.properties["azureml.telemetry.attribute"] = "scatter-gather"

        # set output to final aggregation step's output
-         self._outputs = self.subgraph[-1]["aggregation"].outputs
+         self._outputs = self.scatter_gather_graph.outputs
        super(FLScatterGather, self).__init__(
            type=JobType.COMPONENT,  # pylint: disable=redefined-builtin
            component=None,
            inputs=None,
-             outputs=self.subgraph[-1]["aggregation"].outputs,
+             outputs=self.scatter_gather_graph.outputs,
            name=None,
            display_name=None,
            description=None,
@@ -142,14 +128,117 @@ def __init__(
            experiment_name=None,
        )

+     def scatter_gather(self):
+         @pipeline(
+             name="Scatter gather",
+             description="Includes all steps that need to be executed in the silos and during aggregation",
+         )
+         def scatter_gather_iteration_body(**silo_inputs):
+             """
+             Performs a scatter-gather iteration by running copies of the silo step on different
+             computes/datastores according to this node's silo configs. The outputs of these
+             silo components are then merged by an internal helper component. The merged values
+             are then inputted into the user-provided aggregation component. Returns the executed aggregation component.
+             Args:
+                 silo_inputs (dict): A dictionary of names and Inputs to be injected into each executed silo step.
+                     This dictionary is merged with silo-specific inputs before each execution.
+             """
+
+             silo_outputs = []
+             # TODO 2293586 replace this for-loop with a parallel-for node
+             for silo_config in self.silo_configs:
+                 silo_inputs.update(silo_config.inputs)
+                 executed_silo_component = self.silo_component(**silo_inputs)
+                 for v, k in executed_silo_component.inputs.items():
+                     if v in silo_config.inputs and k.type == "uri_folder":
+                         k.mode = "ro_mount"
+                 FLScatterGather._anchor_step(
+                     pipeline_step=executed_silo_component,
+                     compute=silo_config.compute,
+                     internal_datastore=silo_config.datastore,
+                     orchestrator_datastore=self.aggregation_datastore,
+                 )
+                 # add to silo outputs list
+                 silo_outputs.append(executed_silo_component)
+
+             # produce internal argument-merging components for the silo outputs
+             merge_comp_mapping = self._inject_merge_components(silo_outputs)
+
+             # produce aggregate step inputs by merging static kwargs and mapped arguments from
+             # internal merge components
+             agg_inputs = {}
+             agg_inputs.update(self.aggregation_kwargs)
+             internal_merge_outputs = {
+                 self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items()
+             }
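+             # each merge component's aggregated_output is keyed by the aggregation-side input
+             # name, so the user's aggregation component receives it under the name it expects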
+             agg_inputs.update(internal_merge_outputs)
+
+             # run the user aggregation step
+             executed_aggregation_component = self.aggregation_component(**agg_inputs)
+             # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table
+             # to be accessible by the component
+             for name, agg_input in executed_aggregation_component.inputs.items():
+                 if name in self.silo_to_aggregation_argument_map.keys() and agg_input.type == "mltable":
+                     agg_input.mode = "eval_download"
+
+             # Anchor both the internal merge components and the user-supplied aggregation step
+             # to the aggregation compute and datastore
+             if self.aggregation_compute is not None and self.aggregation_datastore is not None:
+                 # internal merge component is also siloed to wherever the aggregation component lives.
+                 for executed_merge_component in merge_comp_mapping.values():
+                     FLScatterGather._anchor_step(
+                         pipeline_step=executed_merge_component,
+                         compute=self.aggregation_compute,
+                         internal_datastore=self.aggregation_datastore,
+                         orchestrator_datastore=self.aggregation_datastore,
+                     )
+                 FLScatterGather._anchor_step(
+                     pipeline_step=executed_aggregation_component,
+                     compute=self.aggregation_compute,
+                     internal_datastore=self.aggregation_datastore,
+                     orchestrator_datastore=self.aggregation_datastore,
+                 )
+             return executed_aggregation_component.outputs
+
+         @pipeline(name="Scatter gather graph")
+         def create_scatter_gather_graph():
+             """
+             Creates a scatter-gather graph by executing the scatter_gather_iteration_body
+             function in a do-while loop. The loop terminates after the user-supplied
+             maximum number of iterations.
+             """
+
+             silo_inputs = {}
+             # Start with static inputs
+             silo_inputs.update(self.shared_silo_kwargs)
+
+             # merge in inputs passed in from the previous iteration's aggregate step
+             if self.aggregation_to_silo_argument_map is not None:
+                 silo_inputs.update({v: None for v in self.aggregation_to_silo_argument_map.values()})
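+                 # the None placeholders surface these names as inputs on the iteration body so the
+                 # do_while mapping below can feed each iteration's aggregation outputs back into them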
+
+             scatter_gather_body = scatter_gather_iteration_body(**silo_inputs)
+
+             # map aggregation outputs to scatter inputs
+             do_while_mapping = {
+                 k: getattr(scatter_gather_body.inputs, v) for k, v in self.aggregation_to_silo_argument_map.items()
+             }
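+             # do_while reruns scatter_gather_body with the mapped outputs wired back into its inputs,
+             # stopping after max_iterations passes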
+
+             do_while(
+                 body=scatter_gather_body,
+                 mapping=do_while_mapping,
+                 max_iteration_count=self.max_iterations,
+             )
+             return scatter_gather_body.outputs
+
+         return create_scatter_gather_graph()
+
    # TODO potentially set default fail_on_missing value to false
    @classmethod
    def _extract_outputs(cls, component_output: Output, argument_map: Dict, fail_on_missing=False):
        """
        Pulls values from a component_output, as specified by the keys of the
        inputted argument_map, and groups them in a new dictionary. The keys of the new dictionary
        are specified by the values of the argument_map dictionary.
-
        Example
            component_output = {"one": 1, "two": 2, "three": 3}
            argument_map = {"one": "red", "two": "two"}
@@ -172,99 +261,13 @@ def _extract_outputs(cls, component_output: Output, argument_map: Dict, fail_on_
            result[v] = component_output[k]
        return result

-     def scatter_gather(self, silo_inputs: Dict):
-         """
-         Performs a scatter-gather iteration by running copies of the silo step on different
-         computes/datstores according to this node's silo configs. The outputs of these
-         silo components are then merged by an internal helper component. The merged values
-         are then inputted into the user-provided aggregation component. Returns the executed aggregation component.
-
-         Args:
-             silo_inputs (dict): A dictionary of names and Inputs to be injected into each executed silo step.
-                 This dictionary is merged with silo-specific inputs before each executed.
-         Returns:
-             sg_graph (dict): A dictionary containing the subgraph of this scatter-gather iteration.
-                 Contains three indexes which contain a list of executed silo steps, a list of internally
-                 created and executed merge components, and the executed aggregation step.
-         """
-
-         sg_graph = {"silo_steps": []}
-         siloed_outputs = {}
-         # TODO 2293586 replace this for-loop with a parallel-for node
-         # pylint: disable=consider-using-enumerate
-         for i in range(len(self.silo_configs)):
-             silo_config = self.silo_configs[i]
-             silo_inputs.update(silo_config.inputs)
-             executed_silo_component = self.silo_component(**silo_inputs)
-             for v, k in executed_silo_component.inputs.items():
-                 if v in silo_config.inputs and k.type == "uri_folder":
-                     k.mode = "ro_mount"
-             FLScatterGather._anchor_step(
-                 pipeline_step=executed_silo_component,
-                 compute=silo_config.compute,
-                 internal_datastore=silo_config.datastore,
-                 orchestrator_datastore=self.aggregation_datastore,
-             )
-             # include executed silo steps in recorded subgraph
-             sg_graph["silo_steps"].append(executed_silo_component)
-
-             # Extract user-specified outputs from the silo component, rename them as needed,
-             # annotate them with the silo's index, then jam them all into the
-             # variable-length internal component's input list.
-             siloed_outputs.update(
-                 {
-                     "{}_{}".format(k, i): v
-                     for k, v in FLScatterGather._extract_outputs(
-                         executed_silo_component.outputs, self.silo_to_aggregation_argument_map
-                     ).items()
-                 }
-             )
-
-         # produce internal argument-merging components and record them in local subgraph
-         merge_comp_mapping = self._inject_merge_components(sg_graph["silo_steps"])
-         sg_graph["mergers"] = list(merge_comp_mapping.values())
-
-         # produce aggregate step inputs by merging static kwargs and mapped arguments from
-         # internal merge components
-         agg_inputs = {}
-         agg_inputs.update(self.aggregation_kwargs)
-         internal_merge_outputs = {
-             self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items()
-         }
-         agg_inputs.update(internal_merge_outputs)
-
-         # run the user aggregation step
-         executed_aggregation_component = self.aggregation_component(**agg_inputs)
-         # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table
-         # to be accessible by the component
-         for name, agg_input in executed_aggregation_component.inputs.items():
-             if name in self.silo_to_aggregation_argument_map.keys() and agg_input.type == "mltable":
-                 agg_input.mode = "eval_download"
-         # record aggregation step in subgraph
-         sg_graph["aggregation"] = executed_aggregation_component
-
-         # Anchor both the internal merge components and the user-supplied aggregation step
-         # to the aggregation compute and datastore
-         if self.aggregation_compute is not None and self.aggregation_datastore is not None:
-             # internal merge component is also siloed to wherever the aggregation component lives.
-             for executed_merge_component in merge_comp_mapping.values():
-                 FLScatterGather._anchor_step(
-                     pipeline_step=executed_merge_component,
-                     compute=self.aggregation_compute,
-                     internal_datastore=self.aggregation_datastore,
-                     orchestrator_datastore=self.aggregation_datastore,
-                 )
-             FLScatterGather._anchor_step(
-                 pipeline_step=executed_aggregation_component,
-                 compute=self.aggregation_compute,
-                 internal_datastore=self.aggregation_datastore,
-                 orchestrator_datastore=self.aggregation_datastore,
-             )
-         return sg_graph
-
    @classmethod
    def _get_fl_datastore_path(
-         cls, datastore_name: str, output_name: str, unique_id: str = "${{name}}", iteration_num: Optional[int] = None
+         cls,
+         datastore_name: str,
+         output_name: str,
+         unique_id: str = "${{name}}",
+         iteration_num: Optional[int] = None,
    ) -> str:
        """Construct a path string using the inputted values. The important aspect is that this produces a
        path with a specified datastore.
@@ -434,7 +437,10 @@ def _anchor_step(
                if output.type in ANCHORABLE_OUTPUT_TYPES:
                    validation_result.merge_with(
                        cls._check_or_set_datastore(
-                             name=name, output=output, target_datastore=orchestrator_datastore, iteration_num=iteration
+                             name=name,
+                             output=output,
+                             target_datastore=orchestrator_datastore,
+                             iteration_num=iteration,
                        )
                    )
                else:
@@ -463,7 +469,6 @@ def validate_inputs(
        max_iterations: int,
        raise_error=False,
    ) -> MutableValidationResult:
-
        validation_result = cls._create_empty_validation_result()

        # saved values for validation later on
@@ -474,7 +479,8 @@ def validate_inputs(
        # validate silo component
        if silo_component is None:
            validation_result.append_error(
-                 yaml_path="silo_component", message="silo_component is a required argument for the scatter gather node."
+                 yaml_path="silo_component",
+                 message="silo_component is a required argument for the scatter gather node.",
            )
        else:
            # ensure that silo component has both inputs and outputs
@@ -522,10 +528,14 @@ def validate_inputs(
        # validate silos configs
        if silo_configs is None:
            validation_result.append_error(
-                 yaml_path="silo_configs", message="silo_configs is a required argument for the scatter gather node."
+                 yaml_path="silo_configs",
+                 message="silo_configs is a required argument for the scatter gather node.",
            )
        elif len(silo_configs) == 0:
-             validation_result.append_error(yaml_path="silo_configs", message="silo_configs cannot be an empty list.")
+             validation_result.append_error(
+                 yaml_path="silo_configs",
+                 message="silo_configs cannot be an empty list.",
+             )
        else:
            first_silo = silo_configs[0]
            expected_inputs = []
@@ -646,7 +656,8 @@ def validate_inputs(

        if max_iterations < 1:
            validation_result.append_error(
-                 yaml_path="max_iterations", message=f"max_iterations must be a positive value, not '{max_iterations}'."
+                 yaml_path="max_iterations",
+                 message=f"max_iterations must be a positive value, not '{max_iterations}'.",
            )

        return validation_result.try_raise(
@@ -656,7 +667,11 @@ def validate_inputs(

    @classmethod
    def _custom_fl_data_path(
-         cls, datastore_name, output_name, unique_id="${{name}}", iteration_num="${{iteration_num}}"
+         cls,
+         datastore_name,
+         output_name,
+         unique_id="${{name}}",
+         iteration_num="${{iteration_num}}",
    ):
        """Produces a path to store the data during FL training.
        Args:
@@ -682,7 +697,11 @@ def _get_aggregator_input_name(self, silo_output_name):

    @classmethod
    def _try_create_default_mappings(
-         cls, silo_comp: Component, agg_comp: Component, silo_agg_map: Dict, agg_silo_map: Dict
+         cls,
+         silo_comp: Component,
+         agg_comp: Component,
+         silo_agg_map: Dict,
+         agg_silo_map: Dict,
    ):
        """
        This function tries to produce dictionaries that link the silo and aggregation
@@ -739,7 +758,10 @@ def _inject_merge_components(self, executed_silo_components):
        executed_component = executed_silo_components[0]

        merge_comp_mapping = {}
-         for silo_output_argument_name, _ in self.silo_to_aggregation_argument_map.items():
+         for (
+             silo_output_argument_name,
+             _,
+         ) in self.silo_to_aggregation_argument_map.items():
            merge_comp = self._get_merge_component(executed_component.outputs[silo_output_argument_name].type)
            merge_component_inputs = {
                silo_output_argument_name