
Commit 0ca4e7e

Authored by garg-amit, MilesHolland, singankit, and kdestin
Replace the for loop in the Scatter Gather node with a do_while loop. (Azure#29633)
* add temp files
* more stuff
* p-for
* add files before switching branches
* finish up skeleton code
* save progress
* lots of stuff
* remove pdb statements
* temp nodes
* progress pre-main-merge
* make things work
* Changes to aggregate outputs
* Merge, and lots of other stuff
* remove test nodes and add real output
* mark node as experimental
* nit fixes
* Removing dependency on PyYaml
* nits
* linting nits general pass
* lint pass v2
* disable validation
* better datastore anchoring and testing
* Update sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/fl_scatter_gather.py (Co-authored-by: kdestin <[email protected]>)
* PR comments
* testing
* remove pdb
* lint and dev require mldesigner
* spell write
* lint
* more lint
* fix lint-suggested breaking changes
* regenerate test recording
* minor test fix
* run black formatter but for real
* cl and package dep warning
* move test constants to fixtures, make mldesigner import checked in code
* formatting
* rerun black after merge
* implement do_while
* import from dsl module
* add pipeline comp
* update unittest and e2etest
* remove redundant comments
* set telemetry flag for SG node
* formatting + add e2etesting and
* resolve pylint issues
* fix pylint and black issues

---------

Co-authored-by: Miles Holland <[email protected]>
Co-authored-by: MilesHolland <[email protected]>
Co-authored-by: singankit <[email protected]>
Co-authored-by: kdestin <[email protected]>
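For context, here is a minimal, hypothetical sketch (not part of this PR) of the do-while pattern the scatter-gather node now uses: an @pipeline-decorated iteration body is instantiated once, and do_while rewires the body's outputs back onto its own inputs for up to max_iteration_count iterations. The component names (train_in_silo, aggregate_weights) are placeholder assumptions, standing in for user-supplied command components; only the pipeline and do_while calls mirror what the diff below actually does.

from azure.ai.ml.dsl import pipeline
from azure.ai.ml.dsl._do_while import do_while  # import path used by this commit

# train_in_silo and aggregate_weights are hypothetical components, e.g.
# loaded elsewhere via load_component(); they are not defined in this repo.

@pipeline(name="iteration_body")
def iteration_body(agg_weights):
    # one scatter-gather round: train, then aggregate
    silo_step = train_in_silo(weights=agg_weights)
    agg_step = aggregate_weights(silo_output=silo_step.outputs.weights)
    return agg_step.outputs

@pipeline(name="looped_graph")
def looped_graph():
    body = iteration_body(agg_weights=None)
    # Feed the body's aggregated output back into its own input each iteration;
    # each mapping key names a body output, its value is the body input it fills.
    do_while(
        body=body,
        mapping={"aggregated_weights": body.inputs.agg_weights},
        max_iteration_count=3,
    )
    return body.outputs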
1 parent ec53cbd · commit 0ca4e7e

File tree

4 files changed: +243 -197 lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/fl_scatter_gather.py

Lines changed: 150 additions & 128 deletions
@@ -16,13 +16,16 @@
 from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo
 from azure.ai.ml.entities._component.component import Component
 from azure.ai.ml.entities._validation import MutableValidationResult
-from .subcomponents import aggregate_output
+from azure.ai.ml.dsl._do_while import do_while
+from azure.ai.ml.dsl import pipeline
+from .subcomponents import create_scatter_output_table
+
 
 # TODO 2293610: add support for more types of outputs besides uri_folder and mltable
 # Likely types that ought to be mergeable: string, int, uri_file
 MERGE_COMPONENT_MAPPING = {
-    "mltable": aggregate_output,
-    "uri_folder": aggregate_output,
+    "mltable": create_scatter_output_table,
+    "uri_folder": create_scatter_output_table,
 }
 
 
@@ -34,7 +37,7 @@
 # big TODO: For some reason, surfacing this file in __init__.py causes
 # a circular import exception on the first attempted import
 # In notebooks, the second import succeeds, but then causes a silent failure where the
-# MLDesigner component created by the subcomponents.aggregate_output function
+# MLDesigner component created by the subcomponents.create_scatter_output_table function
 # will produce a ComponentExecutor object instead of the actual component.
 # TODO 2293541: Add telemetry of some sort
 # pylint: disable=too-many-instance-attributes
@@ -101,37 +104,20 @@ def __init__(
         self.silo_to_aggregation_argument_map = silo_to_aggregation_argument_map
         self.aggregation_to_silo_argument_map = aggregation_to_silo_argument_map
         self.max_iterations = max_iterations
-        # subgraph is a list of all iteration's executed steps, stored in a dictionary for each iteration
-        self.subgraph = []
         self._init = True  # Needed by parent class to work properly
 
-        executed_aggregation_step = None
-        # TODO 2293573: replace this for-loop with a do-while node.
-        for _ in range(self.max_iterations):
-            # Create inputs for silo components
-            silo_inputs = {}
-            # Start with static inputs
-            silo_inputs.update(self.shared_silo_kwargs)
-            # merge in inputs passed in from previous iteration's aggregate step if they exist
-            if executed_aggregation_step is not None and self.aggregation_to_silo_argument_map is not None:
-                silo_inputs.update(
-                    FLScatterGather._extract_outputs(
-                        executed_aggregation_step.outputs, self.aggregation_to_silo_argument_map
-                    )
-                )
-            # per-silo inputs added in during scatter_gather
-            # Run scatter-gather iteration.
-            self.subgraph.append(self.scatter_gather(silo_inputs))
-            # re-assign previous agg step to extract outputs for next iter's inputs
-            executed_aggregation_step = self.subgraph[-1]["aggregation"]
+        self.scatter_gather_graph = self.scatter_gather()
+
+        # set SG node flag for telemetry
+        self.scatter_gather_graph.properties["azureml.telemetry.attribute"] = "scatter-gather"
 
         # set output to final aggregation step's output
-        self._outputs = self.subgraph[-1]["aggregation"].outputs
+        self._outputs = self.scatter_gather_graph.outputs
         super(FLScatterGather, self).__init__(
             type=JobType.COMPONENT,  # pylint: disable=redefined-builtin
             component=None,
             inputs=None,
-            outputs=self.subgraph[-1]["aggregation"].outputs,
+            outputs=self.scatter_gather_graph.outputs,
             name=None,
             display_name=None,
             description=None,
@@ -142,14 +128,117 @@ def __init__(
             experiment_name=None,
         )
 
+    def scatter_gather(self):
+        @pipeline(
+            name="Scatter gather",
+            description="It includes all steps that need to be executed in silo and aggregation",
+        )
+        def scatter_gather_iteration_body(**silo_inputs):
+            """
+            Performs a scatter-gather iteration by running copies of the silo step on different
+            computes/datastores according to this node's silo configs. The outputs of these
+            silo components are then merged by an internal helper component. The merged values
+            are then inputted into the user-provided aggregation component. Returns the executed aggregation component.
+            Args:
+                silo_inputs (dict): A dictionary of names and Inputs to be injected into each executed silo step.
+                    This dictionary is merged with silo-specific inputs before each execution.
+            """
+
+            silo_outputs = []
+            # TODO 2293586 replace this for-loop with a parallel-for node
+            for silo_config in self.silo_configs:
+                silo_inputs.update(silo_config.inputs)
+                executed_silo_component = self.silo_component(**silo_inputs)
+                for v, k in executed_silo_component.inputs.items():
+                    if v in silo_config.inputs and k.type == "uri_folder":
+                        k.mode = "ro_mount"
+                FLScatterGather._anchor_step(
+                    pipeline_step=executed_silo_component,
+                    compute=silo_config.compute,
+                    internal_datastore=silo_config.datastore,
+                    orchestrator_datastore=self.aggregation_datastore,
+                )
+                # add to silo outputs list
+                silo_outputs.append(executed_silo_component)
+
+            # produce internal argument-merging components and record them in local subgraph
+            merge_comp_mapping = self._inject_merge_components(silo_outputs)
+
+            # produce aggregate step inputs by merging static kwargs and mapped arguments from
+            # internal merge components
+            agg_inputs = {}
+            agg_inputs.update(self.aggregation_kwargs)
+            internal_merge_outputs = {
+                self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items()
+            }
+            agg_inputs.update(internal_merge_outputs)
+
+            # run the user aggregation step
+            executed_aggregation_component = self.aggregation_component(**agg_inputs)
+            # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table
+            # to be accessible by the component
+            for name, agg_input in executed_aggregation_component.inputs.items():
+                if name in self.silo_to_aggregation_argument_map.keys() and agg_input.type == "mltable":
+                    agg_input.mode = "eval_download"
+
+            # Anchor both the internal merge components and the user-supplied aggregation step
+            # to the aggregation compute and datastore
+            if self.aggregation_compute is not None and self.aggregation_datastore is not None:
+                # internal merge component is also siloed to wherever the aggregation component lives.
+                for executed_merge_component in merge_comp_mapping.values():
+                    FLScatterGather._anchor_step(
+                        pipeline_step=executed_merge_component,
+                        compute=self.aggregation_compute,
+                        internal_datastore=self.aggregation_datastore,
+                        orchestrator_datastore=self.aggregation_datastore,
+                    )
+                FLScatterGather._anchor_step(
+                    pipeline_step=executed_aggregation_component,
+                    compute=self.aggregation_compute,
+                    internal_datastore=self.aggregation_datastore,
+                    orchestrator_datastore=self.aggregation_datastore,
+                )
+            return executed_aggregation_component.outputs
+
+        @pipeline(name="Scatter gather graph")
+        def create_scatter_gather_graph():
+            """
+            Creates a scatter-gather graph by executing the scatter_gather_iteration_body
+            function in a do-while loop. The loop terminates when the user-supplied
+            termination condition is met.
+            """
+
+            silo_inputs = {}
+            # Start with static inputs
+            silo_inputs.update(self.shared_silo_kwargs)
+
+            # merge in inputs passed in from previous iteration's aggregate step
+            if self.aggregation_to_silo_argument_map is not None:
+                silo_inputs.update({v: None for v in self.aggregation_to_silo_argument_map.values()})
+
+            scatter_gather_body = scatter_gather_iteration_body(**silo_inputs)
+
+            # map aggregation outputs to scatter inputs
+            do_while_mapping = {
+                k: getattr(scatter_gather_body.inputs, v) for k, v in self.aggregation_to_silo_argument_map.items()
+            }
+
+            do_while(
+                body=scatter_gather_body,
+                mapping=do_while_mapping,
+                max_iteration_count=self.max_iterations,
+            )
+            return scatter_gather_body.outputs
+
+        return create_scatter_gather_graph()
+
     # TODO potential set default fail_on_missing value to false
     @classmethod
     def _extract_outputs(cls, component_output: Output, argument_map: Dict, fail_on_missing=False):
         """
         Pulls values from a component_output, as specified by the keys of the
         inputted argument_map, and groups in a new dictionary. The keys of the new dictionary
         are specified by the items of the argument_map dictionary.
-
         Example
             component_output = {"one" : 1, "two": 2, "three" : 3}
             argument_map = {"one" : "red", "two" : "two"}
@@ -172,99 +261,13 @@ def _extract_outputs(cls, component_output: Output, argument_map: Dict, fail_on_
                 result[v] = component_output[k]
         return result
 
-    def scatter_gather(self, silo_inputs: Dict):
-        """
-        Performs a scatter-gather iteration by running copies of the silo step on different
-        computes/datstores according to this node's silo configs. The outputs of these
-        silo components are then merged by an internal helper component. The merged values
-        are then inputted into the user-provided aggregation component. Returns the executed aggregation component.
-
-        Args:
-            silo_inputs (dict): A dictionary of names and Inputs to be injected into each executed silo step.
-                This dictionary is merged with silo-specific inputs before each executed.
-        Returns:
-            sg_graph (dict): A dictionary containing the subgraph of this scatter-gather iteration.
-                Contains three indexes which contain a list of executed silo steps, a list of internally
-                created and executed merge components, and the executed aggregation step.
-        """
-
-        sg_graph = {"silo_steps": []}
-        siloed_outputs = {}
-        # TODO 2293586 replace this for-loop with a parallel-for node
-        # pylint: disable=consider-using-enumerate
-        for i in range(len(self.silo_configs)):
-            silo_config = self.silo_configs[i]
-            silo_inputs.update(silo_config.inputs)
-            executed_silo_component = self.silo_component(**silo_inputs)
-            for v, k in executed_silo_component.inputs.items():
-                if v in silo_config.inputs and k.type == "uri_folder":
-                    k.mode = "ro_mount"
-            FLScatterGather._anchor_step(
-                pipeline_step=executed_silo_component,
-                compute=silo_config.compute,
-                internal_datastore=silo_config.datastore,
-                orchestrator_datastore=self.aggregation_datastore,
-            )
-            # include executed silo steps in recorded subgraph
-            sg_graph["silo_steps"].append(executed_silo_component)
-
-            # Extract user-specified outputs from the silo component, rename them as needed,
-            # annotate them with the silo's index, then jam them all into the
-            # variable-length internal component's input list.
-            siloed_outputs.update(
-                {
-                    "{}_{}".format(k, i): v
-                    for k, v in FLScatterGather._extract_outputs(
-                        executed_silo_component.outputs, self.silo_to_aggregation_argument_map
-                    ).items()
-                }
-            )
-
-        # produce internal argument-merging components and record them in local subgraph
-        merge_comp_mapping = self._inject_merge_components(sg_graph["silo_steps"])
-        sg_graph["mergers"] = list(merge_comp_mapping.values())
-
-        # produce aggregate step inputs by merging static kwargs and mapped arguments from
-        # internal merge components
-        agg_inputs = {}
-        agg_inputs.update(self.aggregation_kwargs)
-        internal_merge_outputs = {
-            self._get_aggregator_input_name(k): v.outputs.aggregated_output for k, v in merge_comp_mapping.items()
-        }
-        agg_inputs.update(internal_merge_outputs)
-
-        # run the user aggregation step
-        executed_aggregation_component = self.aggregation_component(**agg_inputs)
-        # Set mode of aggregated mltable inputs as eval mount to allow files referenced within the table
-        # to be accessible by the component
-        for name, agg_input in executed_aggregation_component.inputs.items():
-            if name in self.silo_to_aggregation_argument_map.keys() and agg_input.type == "mltable":
-                agg_input.mode = "eval_download"
-        # record aggregation step in subgraph
-        sg_graph["aggregation"] = executed_aggregation_component
-
-        # Anchor both the internal merge components and the user-supplied aggregation step
-        # to the aggregation compute and datastore
-        if self.aggregation_compute is not None and self.aggregation_datastore is not None:
-            # internal merge component is also siloed to wherever the aggregation component lives.
-            for executed_merge_component in merge_comp_mapping.values():
-                FLScatterGather._anchor_step(
-                    pipeline_step=executed_merge_component,
-                    compute=self.aggregation_compute,
-                    internal_datastore=self.aggregation_datastore,
-                    orchestrator_datastore=self.aggregation_datastore,
-                )
-            FLScatterGather._anchor_step(
-                pipeline_step=executed_aggregation_component,
-                compute=self.aggregation_compute,
-                internal_datastore=self.aggregation_datastore,
-                orchestrator_datastore=self.aggregation_datastore,
-            )
-        return sg_graph
-
     @classmethod
     def _get_fl_datastore_path(
-        cls, datastore_name: str, output_name: str, unique_id: str = "${{name}}", iteration_num: Optional[int] = None
+        cls,
+        datastore_name: str,
+        output_name: str,
+        unique_id: str = "${{name}}",
+        iteration_num: Optional[int] = None,
     ) -> str:
         """Construct a path string using the inputted values. The important aspect is that this produces a
         path with a specified datastore.
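The Example in the _extract_outputs docstring is cut off by the hunk boundary above; based on the `result[v] = component_output[k]` line visible in the context, its renaming behavior can be sketched in plain Python as follows (a mimic of the method's core logic, not the method itself):

component_output = {"one": 1, "two": 2, "three": 3}
argument_map = {"one": "red", "two": "two"}

# keys of argument_map select entries; values supply the new names
result = {v: component_output[k] for k, v in argument_map.items()}
print(result)  # {'red': 1, 'two': 2}; "three" is dropped, "one" is renamed to "red"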
@@ -434,7 +437,10 @@ def _anchor_step(
             if output.type in ANCHORABLE_OUTPUT_TYPES:
                 validation_result.merge_with(
                     cls._check_or_set_datastore(
-                        name=name, output=output, target_datastore=orchestrator_datastore, iteration_num=iteration
+                        name=name,
+                        output=output,
+                        target_datastore=orchestrator_datastore,
+                        iteration_num=iteration,
                     )
                 )
             else:
@@ -463,7 +469,6 @@ def validate_inputs(
         max_iterations: int,
         raise_error=False,
     ) -> MutableValidationResult:
-
         validation_result = cls._create_empty_validation_result()
 
         # saved values for validation later on
@@ -474,7 +479,8 @@ def validate_inputs(
         # validate silo component
         if silo_component is None:
             validation_result.append_error(
-                yaml_path="silo_component", message="silo_component is a required argument for the scatter gather node."
+                yaml_path="silo_component",
+                message="silo_component is a required argument for the scatter gather node.",
             )
         else:
             # ensure that silo component has both inputs and outputs
@@ -522,10 +528,14 @@ def validate_inputs(
         # validate silos configs
         if silo_configs is None:
             validation_result.append_error(
-                yaml_path="silo_configs", message="silo_configs is a required argument for the scatter gather node."
+                yaml_path="silo_configs",
+                message="silo_configs is a required argument for the scatter gather node.",
             )
         elif len(silo_configs) == 0:
-            validation_result.append_error(yaml_path="silo_configs", message="silo_configs cannot be an empty list.")
+            validation_result.append_error(
+                yaml_path="silo_configs",
+                message="silo_configs cannot be an empty list.",
+            )
         else:
             first_silo = silo_configs[0]
             expected_inputs = []
@@ -646,7 +656,8 @@ def validate_inputs(
 
         if max_iterations < 1:
             validation_result.append_error(
-                yaml_path="max_iterations", message=f"max_iterations must be a positive value, not '{max_iterations}'."
+                yaml_path="max_iterations",
+                message=f"max_iterations must be a positive value, not '{max_iterations}'.",
            )
 
         return validation_result.try_raise(
@@ -656,7 +667,11 @@ def validate_inputs(
 
     @classmethod
     def _custom_fl_data_path(
-        cls, datastore_name, output_name, unique_id="${{name}}", iteration_num="${{iteration_num}}"
+        cls,
+        datastore_name,
+        output_name,
+        unique_id="${{name}}",
+        iteration_num="${{iteration_num}}",
     ):
         """Produces a path to store the data during FL training.
         Args:
@@ -682,7 +697,11 @@ def _get_aggregator_input_name(self, silo_output_name):
 
     @classmethod
     def _try_create_default_mappings(
-        cls, silo_comp: Component, agg_comp: Component, silo_agg_map: Dict, agg_silo_map: Dict
+        cls,
+        silo_comp: Component,
+        agg_comp: Component,
+        silo_agg_map: Dict,
+        agg_silo_map: Dict,
     ):
         """
         This function tries to produce dictionaries that link the silo and aggregation
@@ -739,7 +758,10 @@ def _inject_merge_components(self, executed_silo_components):
         executed_component = executed_silo_components[0]
 
         merge_comp_mapping = {}
-        for silo_output_argument_name, _ in self.silo_to_aggregation_argument_map.items():
+        for (
+            silo_output_argument_name,
+            _,
+        ) in self.silo_to_aggregation_argument_map.items():
             merge_comp = self._get_merge_component(executed_component.outputs[silo_output_argument_name].type)
             merge_component_inputs = {
                 silo_output_argument_name
