From 419a8b1c5053437d9b007e8bfe3704e963678c5f Mon Sep 17 00:00:00 2001 From: Victor Li Date: Mon, 20 Oct 2025 14:48:23 -0700 Subject: [PATCH 1/5] readding generic mcmc --- .../compiler/mcmc/generic_mcmc_algorithm.h | 57 +++++++++++++++++++ .../mcmc/generic_mcmc_config.struct.toml | 19 +++++++ .../compiler/mcmc/generic_mcmc_state.h | 28 +++++++++ .../compiler/mcmc/generic_mcmc_algorithm.cc | 1 + .../src/compiler/mcmc/generic_mcmc_state.cc | 12 ++++ .../compiler/mcmc/generic_mcmc_algorithm.cc | 32 +++++++++++ 6 files changed, 149 insertions(+) create mode 100644 lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h create mode 100644 lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml create mode 100644 lib/compiler/include/compiler/mcmc/generic_mcmc_state.h create mode 100644 lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc create mode 100644 lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc create mode 100644 lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h new file mode 100644 index 0000000000..a27ecbc8f4 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h @@ -0,0 +1,57 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H +#define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H + +#include "compiler/mcmc/generic_mcmc_config.dtg.h" +#include "compiler/mcmc/generic_mcmc_state.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/random_utils.h" +#include + +namespace FlexFlow { + +template +void modify_state_for_minimization( + Generic_MCMC_state &best_state, + Generic_MCMC_state ¤t_state, + State candidate, + ScoringFunc scorer, + float temperature) { + float best_estimate = best_state.get_score(); + float new_estimate = scorer(candidate); + float delta = new_estimate - best_estimate; + if (delta < 0 || (randf() < exp(-delta / temperature))) { + current_state = Generic_MCMC_state(candidate, new_estimate); + if (delta < 0) { + best_state = current_state; + } + } +} + +// GeneratingFunc : State -> nn_int -> std::optional +// ScoringFunc : State -> float + +template +Generic_MCMC_state + minimize_score(State const &starting_state, + GeneratingFunc const &generator, + ScoringFunc const &scorer, + GenericMCMCConfig const &search_config) { + using MCMCState = Generic_MCMC_state; + MCMCState best_state = MCMCState(starting_state, scorer(starting_state)); + MCMCState current_state = best_state; + for (nonnegative_int i : nonnegative_range(search_config.num_iterations)) { + std::optional candidate = generator(current_state.get_state(), i); + if (candidate != std::nullopt) { + modify_state_for_minimization(best_state, + current_state, + candidate.value(), + scorer, + search_config.temperature); + } + } + return best_state; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml new file mode 100644 index 0000000000..e11c84f0bd --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "GenericMCMCConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "temperature" +type = "float" + +[[fields]] +name = "num_iterations" +type = "::FlexFlow::nonnegative_int" \ No newline at end of file diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h new file mode 100644 index 0000000000..54e2911bd0 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h @@ -0,0 +1,28 @@ + +#ifndef _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_STATE_H +#define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_STATE_H +#include "utils/nonnegative_int/nonnegative_int.h" + +namespace FlexFlow { + +template +struct Generic_MCMC_state { +public: + Generic_MCMC_state(State const &state, Score const &score) + : state(state), score(score) {} + + State const &get_state() const { + return state; + } + Score const &get_score() const { + return score; + } + +private: + State state; + Score score; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc new file mode 100644 index 0000000000..1bf4f5c2b7 --- /dev/null +++ b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -0,0 +1 @@ +#include "compiler/mcmc/generic_mcmc_algorithm.h" diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc new file mode 100644 index 0000000000..6aa4dd5eff --- /dev/null +++ b/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc @@ -0,0 +1,12 @@ +#include "compiler/mcmc/generic_mcmc_state.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { +using State = value_type<0>; +using Score = ordered_value_type<1>; + +template struct Generic_MCMC_state; +template struct Generic_MCMC_state; + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc new file mode 100644 index 0000000000..ba6faa93c4 --- /dev/null +++ b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -0,0 +1,32 @@ +#include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "doctest/doctest.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("generic_mcmc_algorithm") { + float starting_state = 0.1; + auto generating_func = [](float x, + nonnegative_int i) -> std::optional { + float new_x = x + (randf() - 0.5) / (i.unwrap_nonnegative() + 1); + if (new_x < 0) { + return std::nullopt; + } + if (new_x > 1) { + return std::nullopt; + } + return new_x; + }; + auto scoring_func = [](float x) { return (x - 0.5) * (x - 0.5); }; + GenericMCMCConfig config = GenericMCMCConfig{/*temperature=*/1.0, + /*num_iterations=*/10_n}; + Generic_MCMC_state result = + minimize_score(starting_state, generating_func, scoring_func, config); + float answer = result.get_state(); + float error = result.get_score(); + CHECK(answer > 0.49); + CHECK(answer < 0.51); + CHECK(error >= 0); + CHECK(error < 0.01); + } +} From 79a3d192ad4b07f16c02f835923c09170dfd2fd7 Mon Sep 17 00:00:00 2001 From: Victor Li Date: Wed, 22 Oct 2025 02:35:56 -0700 Subject: [PATCH 2/5] readding mcmc --- .../machine_mapping/allowed_machine_views.h | 21 ++ ..._substitution_and_update_machine_mapping.h | 32 +++ .../machine_mapping_mutation_set.h | 19 ++ .../compiler/mcmc/mcmc_over_mapped_pcg.h | 22 ++ .../mcmc_over_mapped_pcg_config.struct.toml | 28 +++ lib/compiler/include/compiler/search_result.h | 13 ++ .../compiler/search_result.struct.toml | 17 ++ .../src/compiler/allowed_machine_views.cc | 2 + ...substitution_and_update_machine_mapping.cc | 197 ++++++++++++++++++ .../machine_mapping_mutation_set.cc | 52 +++++ .../src/compiler/mcmc/mcmc_over_mapped_pcg.cc | 76 +++++++ lib/compiler/src/compiler/search_result.cc | 15 ++ .../src/compiler/mcmc/mcmc_over_mapped_pcg.cc | 94 +++++++++ .../include/substitutions/pcg_pattern.h | 4 + .../substitutions/unity_substitution_set.h | 3 + .../operator_pattern/satisfies_constraint.cc | 28 +++ .../src/substitutions/pcg_pattern.cc | 12 ++ .../tensor_pattern/satisfies_constraint.cc | 10 + .../substitutions/unity_substitution_set.cc | 10 + lib/utils/include/utils/random_utils.h | 2 +- 20 files changed, 656 insertions(+), 1 deletion(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h create mode 100644 lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h create mode 100644 lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h create mode 100644 lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml create mode 100644 lib/compiler/include/compiler/search_result.h create mode 100644 lib/compiler/include/compiler/search_result.struct.toml create mode 100644 lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc create mode 100644 lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc create mode 100644 lib/compiler/src/compiler/search_result.cc create mode 100644 lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc diff --git a/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h new file mode 100644 index 0000000000..9bb73fd1a9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H +#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H + +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms); + +std::unordered_set + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h new file mode 100644 index 0000000000..b08ca57851 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H + +#include "compiler/search_result.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" + +namespace FlexFlow { +/** + * @brief Applies \p substitution to \p mapped_pcg at the location specified by + * \p match, returning the resulting SearchResult (mapped pcg) + * + * @param mapped_pcg + * @param substitution + * @param match The location at which to apply substitution. This location in + * sub_pcg should match substitution's PCGPattern. Likely created by running + * FlexFlow::find_pattern_matches(PCGPattern const &, + * SubParallelComputationGraph const &). + * @return SearchResult A mapped pcg similar to mapped_pcg, but with + * the subgraph of the pcg specified by match replaced with the result of the + * output expression of substitution and the machine mapping updated to account + * for the new output + */ +SearchResult apply_substitution_and_update_machine_mapping( + SearchResult const &mapped_pcg, + Substitution const &sub, + PCGPatternMatch const &match); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h new file mode 100644 index 0000000000..43af640e02 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H + +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/search_result.dtg.h" + +namespace FlexFlow { +std::optional + get_naive_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type); + +std::optional + get_random_mutation(SearchResult mapped_pcg, + MachineSpecification const &resource, + DeviceType const &device_type); +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h new file mode 100644 index 0000000000..bc0d2dfb58 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H +#define _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H + +#include "compiler/cost_estimator/runtime_only_cost_estimator.h" +#include "compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.h" +#include "compiler/search_result.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution.h" + +namespace FlexFlow { + +SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml new file mode 100644 index 0000000000..e1548a581e --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MCMCOverMappedPCGConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "temperature" +type = "float" + +[[fields]] +name = "num_iterations" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "substitution_interval" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" \ No newline at end of file diff --git a/lib/compiler/include/compiler/search_result.h b/lib/compiler/include/compiler/search_result.h new file mode 100644 index 0000000000..197b36e9ea --- /dev/null +++ b/lib/compiler/include/compiler/search_result.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H + +#include "compiler/search_result.dtg.h" + +namespace FlexFlow { + +std::string format_as(SearchResult const &); +std::ostream &operator<<(std::ostream &, SearchResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/search_result.struct.toml b/lib/compiler/include/compiler/search_result.struct.toml new file mode 100644 index 0000000000..7e7e59d7c9 --- /dev/null +++ b/lib/compiler/include/compiler/search_result.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SearchResult" +features = [ +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/machine_mapping/machine_mapping.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" \ No newline at end of file diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc index 370cb5a4ec..64b910bf7d 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -57,6 +57,8 @@ static std::unordered_set product(transform(tensor_dims, [](positive_int num_devices) { return nonnegative_int{num_devices.int_from_positive_int() - 1}; })); + min_num_devices_with_full_stride_volume = + std::max(min_num_devices_with_full_stride_volume, 1_n); return ceildiv(total_devices, positive_int{min_num_devices_with_full_stride_volume}); }; diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc new file mode 100644 index 0000000000..252384985b --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc @@ -0,0 +1,197 @@ +#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/apply_substitution/evaluate_substitution_output.h" +#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph_data.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +SearchResult apply_substitution_and_update_machine_mapping( + SearchResult const &mapped_pcg, + Substitution const &sub, + PCGPatternMatch const &match) { + SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg); + + auto substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + SubParallelComputationGraph substitution_output_graph = + substitution_output_result.first; + OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = + substitution_output_result.second; + + SubParallelComputationGraphData output_graph_data = + get_sub_pcg_data(substitution_output_graph); + SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg); + + std::unordered_set pre_nodes = + keys(pre_data.node_data); + std::unordered_set matched_nodes = + unordered_set_of(values(match.node_assignment)); + std::unordered_set post_nodes_from_original_graph = + set_minus(pre_nodes, matched_nodes); + + std::unordered_map machine_views = + mapped_pcg.machine_mapping.machine_views; + + std::unordered_set substituted_machine_views = + transform(matched_nodes, [&](parallel_layer_guid_t const &node) { + return machine_views.at(node); + }); + MachineView first_substituted_machine_view = + *substituted_machine_views.begin(); + + std::unordered_map post_node_data = + [&] { + std::unordered_map + post_node_data_from_orig = restrict_keys( + pre_data.node_data, post_nodes_from_original_graph); + std::unordered_map + post_node_data_from_sub = output_graph_data.node_data; + + for (auto [layer, attrs] : post_node_data_from_sub) { + machine_views.insert_or_assign(layer, first_substituted_machine_view); + } + + return merge_disjoint_maps(post_node_data_from_orig, + post_node_data_from_sub); + }(); + + std::unordered_set post_edges = [&] { + std::unordered_set post_edges_from_orig = + filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) { + if (e.raw_edge.has()) { + return true; + } else { + DataflowEdge dfe = e.raw_edge.get(); + parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node}; + parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node}; + return !(contains(matched_nodes, src) || + contains(matched_nodes, dst)); + } + }); + + std::unordered_set post_edges_from_sub = + filter(output_graph_data.edges, + [&](SubParallelComputationGraphEdge const &e) { + return !e.raw_edge.has(); + }); + + bidict + output_orig_pattern_mapping = get_output_mapping_for_pcg_pattern_match( + match, sub.pcg_pattern, spcg); + bidict + output_post_outexpr_mapping = get_output_graph_expr_output_mapping( + output_expr_to_result_sub_pcg_mapping, + sub.output_graph_expr, + substitution_output_graph); + + std::unordered_set incoming_to_sub_edges; + for (auto const &[pattern_input, base_graph_tensor] : + match.input_assignment) { + OutputGraphExprInput output_expr_input = + sub.inputs_mapping.at_l(pattern_input); + input_parallel_tensor_guid_t output_graph_input = + output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( + output_expr_input); + std::unordered_set uses = get_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); + for (parallel_tensor_use_t const &use : uses) { + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_use(base_graph_tensor, use); + incoming_to_sub_edges.insert(new_edge); + } + } + + std::unordered_set outgoing_from_sub_edges; + for (ParallelComputationGraphEdge const &outgoing_edge : + get_subgraph_outgoing_edges(spcg, matched_nodes)) { + parallel_tensor_guid_t original_tensor = + get_parallel_tensor(outgoing_edge); + PatternNodeOutput pattern_tensor = + output_orig_pattern_mapping.at_r(original_tensor); + OutputGraphExprNodeOutput output_graph_tensor = + sub.outputs_mapping.at_l(pattern_tensor); + parallel_tensor_guid_t new_tensor = + output_post_outexpr_mapping.at_r(output_graph_tensor); + + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_dst( + new_tensor, + get_dst_layer(outgoing_edge), + get_dst_layer_input_idx(outgoing_edge)); + outgoing_from_sub_edges.insert(new_edge); + } + + return set_union(std::vector{ + post_edges_from_orig, + post_edges_from_sub, + incoming_to_sub_edges, + outgoing_from_sub_edges, + }); + }(); + + std::unordered_set post_inputs = + pre_data.inputs; + + std::unordered_map + post_value_data = [&] { + std::unordered_map + post_value_data_from_orig = filter_keys( + pre_data.value_data, [&](open_parallel_tensor_guid_t const &t) { + return visit_open_parallel_tensor_guid( + t, + overload{ + [&](parallel_tensor_guid_t const &t) { + return contains(post_nodes_from_original_graph, + get_source_layer(t)); + }, + [](input_parallel_tensor_guid_t const &) { + return true; + }, + }); + }); + + std::unordered_map + post_value_data_from_sub = output_graph_data.value_data; + return merge_disjoint_maps(post_value_data_from_orig, + post_value_data_from_sub); + }(); + + SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ + post_node_data, + post_edges, + post_inputs, + post_value_data, + }; + + assert(is_subseteq_of(keys(post_node_data), keys(machine_views))); + + for (auto it = machine_views.begin(); it != machine_views.end();) { + if (post_node_data.find(it->first) == post_node_data.end()) { + it = machine_views.erase(it); + } else { + ++it; + } + } + + assert(keys(post_node_data) == keys(machine_views)); + + return SearchResult{ + pcg_from_sub_pcg_by_dropping_inputs(sub_pcg_from_graph_data(post_data)), + MachineMapping{machine_views}}; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc new file mode 100644 index 0000000000..6b20c9963d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc @@ -0,0 +1,52 @@ +#include "compiler/machine_mapping/machine_mapping_mutation_set.h" +#include "compiler/allowed_machine_views.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/vector_of.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/random_utils.h" +#include "utils/vector.h" + +namespace FlexFlow { + +std::optional + get_naive_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type) { + std::vector layers = topological_ordering(pcg); + std::unordered_map machine_views; + for (parallel_layer_guid_t layer : layers) { + OperatorTaskSpace task = get_operator_task_space(pcg, layer); + std::unordered_set allowed_machine_views = + get_allowed_machine_views(resources, task, DeviceType::GPU); + if (allowed_machine_views.empty()) { + return std::nullopt; + } + machine_views.insert({layer, *(allowed_machine_views.begin())}); + } + return MachineMapping{machine_views}; +} + +std::optional + get_random_mutation(SearchResult mapped_pcg, + MachineSpecification const &resources, + DeviceType const &device_type) { + ParallelComputationGraph pcg = mapped_pcg.pcg; + std::vector layers = topological_ordering(pcg); + if (layers.size() == 0) { + return std::nullopt; + } + parallel_layer_guid_t random_layer = select_random(layers); + + MachineMapping machine_mapping = mapped_pcg.machine_mapping; + MachineView machine_view = machine_mapping.machine_views.at(random_layer); + OperatorTaskSpace task = get_operator_task_space(pcg, random_layer); + + std::vector allowed_machine_views = + vector_of(get_allowed_machine_views(resources, task, device_type)); + MachineView random_new_machine_view = select_random(allowed_machine_views); + + machine_mapping.machine_views.at(random_layer) = random_new_machine_view; + return machine_mapping; +} +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc new file mode 100644 index 0000000000..46f26eab2d --- /dev/null +++ b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -0,0 +1,76 @@ +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" +#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_mutation_set.h" +#include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "compiler/search_result.h" +#include "compiler/task_graph_simulator/task_simulator.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/unity_substitution_set.h" +#include "utils/optional.h" + +namespace FlexFlow { + +SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config) { + + std::vector substitutions = get_substitution_set(resources); + + std::optional naive_mapping = + get_naive_mapping(pcg, resources, search_config.device_type); + if (naive_mapping == std::nullopt) { + throw std::runtime_error("Failed to find any solutions"); + } + + SearchResult starting_state = SearchResult{pcg, naive_mapping.value()}; + + auto generating_func = [&](SearchResult mapped_pcg, + nonnegative_int i) -> std::optional { + if (i.unwrap_nonnegative() % + search_config.substitution_interval.unwrap_nonnegative() == + 0) { + // substitutions every (substitution_interval) iterations + std::optional random_substitution = + get_random_substitution(resources); + if (random_substitution != std::nullopt) { + std::optional pattern_match = + get_random_pattern_match(random_substitution.value().pcg_pattern, + sub_pcg_from_full_pcg(mapped_pcg.pcg)); + if (pattern_match != std::nullopt) { + return apply_substitution_and_update_machine_mapping( + mapped_pcg, random_substitution.value(), pattern_match.value()); + } + } + return std::nullopt; + } else { + // machine mapping mutations otherwise + std::optional new_machine_mapping = + get_random_mutation(mapped_pcg, resources, search_config.device_type); + if (new_machine_mapping == std::nullopt) { + return std::nullopt; + } + return SearchResult{mapped_pcg.pcg, new_machine_mapping.value()}; + } + }; + + auto scoring_func = [&](SearchResult mapped_pcg) -> float { + return task_simulator_estimate_forward_pass_time(mapped_pcg.pcg, + cost_estimator, + mapped_pcg.machine_mapping, + resources) + .unwrap_milliseconds(); + }; + + GenericMCMCConfig config = + GenericMCMCConfig{/*temperature*/ search_config.temperature, + /*num_iterations*/ search_config.num_iterations}; + + Generic_MCMC_state result = + minimize_score(starting_state, generating_func, scoring_func, config); + + return result.get_state(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/search_result.cc b/lib/compiler/src/compiler/search_result.cc new file mode 100644 index 0000000000..0afc10723a --- /dev/null +++ b/lib/compiler/src/compiler/search_result.cc @@ -0,0 +1,15 @@ +#include "compiler/search_result.h" + +namespace FlexFlow { + +std::string format_as(SearchResult const &r) { + return fmt::format("", + as_dot(r.pcg), + r.machine_mapping); +} + +std::ostream &operator<<(std::ostream &s, SearchResult const &r) { + return (s << fmt::to_string(r)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc new file mode 100644 index 0000000000..87cffa869f --- /dev/null +++ b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -0,0 +1,94 @@ +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" +#include "compiler/task_graph_simulator/task_simulator.h" +#include "doctest/doctest.h" +#include "internal/runtime_only_cost_estimator_for_test.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_type.dtg.h" +#include "op-attrs/shard_parallel_dim.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("mcmc_graph_optimize") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{32_p, 64_p}, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/16_p, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/8_p, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + RuntimeOnlyCostEstimator cost_estimator = + make_fake_constant_runtime_only_cost_estimator( + /*forward_op_cost=*/10_ms, + /*backward_op_cost=*/10_ms, + /*comm_cost=*/1_ms); + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2_p, + /*num_cpus_per_node=*/1_p, + /*num_gpus_per_node=*/1_p, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MCMCOverMappedPCGConfig no_search = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/1_n, + /*substitution_interval=*/5_n, + /*device_type=*/DeviceType::GPU}; + + SearchResult base_result = + mcmc_graph_optimize(pcg, cost_estimator, full_machine_spec, no_search); + float base_runtime = + task_simulator_estimate_forward_pass_time(base_result.pcg, + cost_estimator, + base_result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + MCMCOverMappedPCGConfig search_config = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/100_n, + /*substitution_interval=*/5_n, + /*device_type=*/DeviceType::GPU}; + + SearchResult result = mcmc_graph_optimize( + pcg, cost_estimator, full_machine_spec, search_config); + float runtime = + task_simulator_estimate_forward_pass_time(result.pcg, + cost_estimator, + result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + CHECK(runtime < base_runtime); + CHECK(runtime < 100); + } +} diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index f0962b15c2..5005a0b51c 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -12,6 +12,10 @@ namespace FlexFlow { std::unordered_set get_nodes(PCGPattern const &); +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg); + /** * @brief Find all locations in \p pcg that match \p pattern */ diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 183f76ac8a..574dd9da3d 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -7,6 +7,9 @@ namespace FlexFlow { +std::optional + get_random_substitution(MachineSpecification const &resources); + std::vector get_substitution_set(MachineSpecification const &resources); diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 194ae49255..4f61eac113 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -16,6 +16,34 @@ bool operator_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + auto get_nonnegative_int_if_possible = + [](OperatorAttributeValue v) -> std::optional { + if (v.has()) { + return v.get(); + } + if (v.has()) { + return v.get().nonnegative_int_from_positive_int(); + } + return std::nullopt; + }; + + if (!expr_val.has_value()) { + throw mk_runtime_error("DIVISIBLE_BY constraint requires " + "nonnegative_int or positive_int values"); + } + + std::optional maybe_expr_val_nn = + get_nonnegative_int_if_possible(expr_val.value()); + std::optional maybe_attr_val_nn = + get_nonnegative_int_if_possible(constraint.attribute_value); + + if (maybe_expr_val_nn.has_value() && maybe_attr_val_nn.has_value()) { + return maybe_expr_val_nn.value() % maybe_attr_val_nn.value() == 0; + } + throw mk_runtime_error("DIVISIBLE_BY constraint requires nonnegative_int " + "or positive_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index a0af875848..1e260f9fe3 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -11,6 +11,7 @@ #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/random_utils.h" namespace FlexFlow { @@ -20,6 +21,17 @@ std::unordered_set get_nodes(PCGPattern const &p) { return transform(raw_nodes, [](Node const &n) { return PatternNode{n}; }); } +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg) { + std::vector pattern_matches = + find_pattern_matches(pattern, pcg); + if (pattern_matches.empty()) { + return std::nullopt; + } + return select_random(pattern_matches); +} + static MatchAdditionalCriterion pcg_pattern_criteria(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index 974bfcabc0..cc0af12c91 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -12,6 +12,16 @@ bool parallel_tensor_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + if (expr_val.has() && + constraint.attribute_value.has()) { + return expr_val.get() % + constraint.attribute_value.get() == + 0; + } + throw mk_runtime_error( + "DIVISIBLE_BY constraint requires nonnegative_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 4b00cdd95f..c8d9266978 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -7,9 +7,19 @@ #include "utils/containers/get_only.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/random_utils.h" namespace FlexFlow { +std::optional + get_random_substitution(MachineSpecification const &resources) { + std::vector substitutions = get_substitution_set(resources); + if (substitutions.empty()) { + return std::nullopt; + } + return select_random(substitutions); +} + std::vector get_substitution_set(MachineSpecification const &resources) { std::vector substitutions; diff --git a/lib/utils/include/utils/random_utils.h b/lib/utils/include/utils/random_utils.h index 99da9646a1..014c38fc51 100644 --- a/lib/utils/include/utils/random_utils.h +++ b/lib/utils/include/utils/random_utils.h @@ -5,7 +5,7 @@ #include #include -float randf() { +inline float randf() { return static_cast(std::rand()) / static_cast(RAND_MAX); } From 6715e9c6c271dcec832b7498f20d687d6c66fc93 Mon Sep 17 00:00:00 2001 From: Victor Li Date: Sun, 30 Nov 2025 19:48:34 -0800 Subject: [PATCH 3/5] removing generic_mcmc_state --- .../compiler/mcmc/generic_mcmc_algorithm.h | 55 +++++++------------ .../compiler/mcmc/generic_mcmc_state.h | 28 ---------- .../src/compiler/mcmc/generic_mcmc_state.cc | 12 ---- .../src/compiler/mcmc/mcmc_over_mapped_pcg.cc | 4 +- .../compiler/mcmc/generic_mcmc_algorithm.cc | 13 ++--- 5 files changed, 27 insertions(+), 85 deletions(-) delete mode 100644 lib/compiler/include/compiler/mcmc/generic_mcmc_state.h delete mode 100644 lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h index a27ecbc8f4..ecba9957d8 100644 --- a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h @@ -2,52 +2,35 @@ #define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H #include "compiler/mcmc/generic_mcmc_config.dtg.h" -#include "compiler/mcmc/generic_mcmc_state.h" #include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/optional.h" +#include "utils/containers/transform.h" #include "utils/random_utils.h" -#include namespace FlexFlow { -template -void modify_state_for_minimization( - Generic_MCMC_state &best_state, - Generic_MCMC_state ¤t_state, - State candidate, - ScoringFunc scorer, - float temperature) { - float best_estimate = best_state.get_score(); - float new_estimate = scorer(candidate); - float delta = new_estimate - best_estimate; - if (delta < 0 || (randf() < exp(-delta / temperature))) { - current_state = Generic_MCMC_state(candidate, new_estimate); - if (delta < 0) { - best_state = current_state; - } - } -} - // GeneratingFunc : State -> nn_int -> std::optional // ScoringFunc : State -> float template -Generic_MCMC_state - minimize_score(State const &starting_state, - GeneratingFunc const &generator, - ScoringFunc const &scorer, - GenericMCMCConfig const &search_config) { - using MCMCState = Generic_MCMC_state; - MCMCState best_state = MCMCState(starting_state, scorer(starting_state)); - MCMCState current_state = best_state; +State minimize_score(State const &starting_state, + GeneratingFunc const &generator, + ScoringFunc const &scorer, + GenericMCMCConfig const &search_config) { + State best_state = starting_state; + State current_state = best_state; for (nonnegative_int i : nonnegative_range(search_config.num_iterations)) { - std::optional candidate = generator(current_state.get_state(), i); - if (candidate != std::nullopt) { - modify_state_for_minimization(best_state, - current_state, - candidate.value(), - scorer, - search_config.temperature); - } + std::optional maybe_new_state = transform(generator(current_state, i), [&](State const &s) { + float delta = scorer(s) - scorer(best_state); + if (randf() < exp(-delta / search_config.temperature)) { + if (delta < 0) { + best_state = s; + } + return s; + } + return current_state; + }); + current_state = or_else(maybe_new_state, [&]() {return current_state;}); } return best_state; } diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h deleted file mode 100644 index 54e2911bd0..0000000000 --- a/lib/compiler/include/compiler/mcmc/generic_mcmc_state.h +++ /dev/null @@ -1,28 +0,0 @@ - -#ifndef _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_STATE_H -#define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_STATE_H -#include "utils/nonnegative_int/nonnegative_int.h" - -namespace FlexFlow { - -template -struct Generic_MCMC_state { -public: - Generic_MCMC_state(State const &state, Score const &score) - : state(state), score(score) {} - - State const &get_state() const { - return state; - } - Score const &get_score() const { - return score; - } - -private: - State state; - Score score; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc deleted file mode 100644 index 6aa4dd5eff..0000000000 --- a/lib/compiler/src/compiler/mcmc/generic_mcmc_state.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "compiler/mcmc/generic_mcmc_state.h" -#include "utils/archetypes/ordered_value_type.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { -using State = value_type<0>; -using Score = ordered_value_type<1>; - -template struct Generic_MCMC_state; -template struct Generic_MCMC_state; - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc index 46f26eab2d..a5bec534ab 100644 --- a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc +++ b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -67,10 +67,10 @@ SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, GenericMCMCConfig{/*temperature*/ search_config.temperature, /*num_iterations*/ search_config.num_iterations}; - Generic_MCMC_state result = + SearchResult result = minimize_score(starting_state, generating_func, scoring_func, config); - return result.get_state(); + return result; } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc index ba6faa93c4..b020006f4d 100644 --- a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc +++ b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -19,14 +19,13 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto scoring_func = [](float x) { return (x - 0.5) * (x - 0.5); }; GenericMCMCConfig config = GenericMCMCConfig{/*temperature=*/1.0, - /*num_iterations=*/10_n}; - Generic_MCMC_state result = + /*num_iterations=*/50_n}; + float answer = minimize_score(starting_state, generating_func, scoring_func, config); - float answer = result.get_state(); - float error = result.get_score(); - CHECK(answer > 0.49); - CHECK(answer < 0.51); + float error = scoring_func(answer); + CHECK(answer > 0.47); + CHECK(answer < 0.53); CHECK(error >= 0); - CHECK(error < 0.01); + CHECK(error < 0.001); } } From 030cd254e42038ebd0cd611dc7082841943e0521 Mon Sep 17 00:00:00 2001 From: Victor Li Date: Mon, 1 Dec 2025 20:32:12 -0800 Subject: [PATCH 4/5] MCMC modifications --- .../machine_mapping_mutation_set.h | 6 +- .../compiler/mcmc/generic_mcmc_algorithm.h | 39 +++++----- .../compiler/mcmc/mcmc_over_mapped_pcg.h | 9 +-- .../mcmc_over_mapped_pcg_config.struct.toml | 4 +- .../machine_mapping_mutation_set.cc | 10 +-- .../compiler/mcmc/generic_mcmc_algorithm.cc | 14 ++++ .../src/compiler/mcmc/mcmc_over_mapped_pcg.cc | 71 ++++++++----------- .../compiler/mcmc/generic_mcmc_algorithm.cc | 14 ++-- .../src/compiler/mcmc/mcmc_over_mapped_pcg.cc | 10 +-- lib/utils/include/utils/optional.h | 3 +- 10 files changed, 93 insertions(+), 87 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h index 43af640e02..16385a74e8 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h @@ -6,9 +6,9 @@ namespace FlexFlow { std::optional - get_naive_mapping(ParallelComputationGraph &pcg, - MachineSpecification const &resources, - DeviceType const &device_type); + get_random_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type); std::optional get_random_mutation(SearchResult mapped_pcg, diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h index ecba9957d8..a3baa251e3 100644 --- a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h @@ -2,35 +2,36 @@ #define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H #include "compiler/mcmc/generic_mcmc_config.dtg.h" +#include "utils/containers/transform.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/optional.h" -#include "utils/containers/transform.h" #include "utils/random_utils.h" namespace FlexFlow { -// GeneratingFunc : State -> nn_int -> std::optional -// ScoringFunc : State -> float +// SamplingFn : State -> std::optional +// CostFn : State -> float -template -State minimize_score(State const &starting_state, - GeneratingFunc const &generator, - ScoringFunc const &scorer, - GenericMCMCConfig const &search_config) { +template +State run_mcmc(State const &starting_state, + SamplingFn const &sampler, + CostFn const &cost, + GenericMCMCConfig const &search_config) { State best_state = starting_state; State current_state = best_state; for (nonnegative_int i : nonnegative_range(search_config.num_iterations)) { - std::optional maybe_new_state = transform(generator(current_state, i), [&](State const &s) { - float delta = scorer(s) - scorer(best_state); - if (randf() < exp(-delta / search_config.temperature)) { - if (delta < 0) { - best_state = s; - } - return s; - } - return current_state; - }); - current_state = or_else(maybe_new_state, [&]() {return current_state;}); + std::optional maybe_new_state = + transform(sampler(current_state), [&](State const &s) { + float delta = cost(s) - cost(best_state); + if (randf() < exp(-delta / search_config.temperature)) { + if (delta < 0) { + best_state = s; + } + return s; + } + return current_state; + }); + current_state = or_else(maybe_new_state, [&]() { return current_state; }); } return best_state; } diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h index bc0d2dfb58..c251340626 100644 --- a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h @@ -12,10 +12,11 @@ namespace FlexFlow { -SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, - RuntimeOnlyCostEstimator const &cost_estimator, - MachineSpecification const &resources, - MCMCOverMappedPCGConfig const &search_config); +SearchResult + mcmc_over_mapped_pcg(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml index e1548a581e..76415ee4d9 100644 --- a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml @@ -20,8 +20,8 @@ name = "num_iterations" type = "::FlexFlow::nonnegative_int" [[fields]] -name = "substitution_interval" -type = "::FlexFlow::nonnegative_int" +name = "substitution_frequency" +type = "float" [[fields]] name = "device_type" diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc index 6b20c9963d..c3c84bb24a 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc @@ -5,14 +5,13 @@ #include "utils/containers/vector_of.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/random_utils.h" -#include "utils/vector.h" namespace FlexFlow { std::optional - get_naive_mapping(ParallelComputationGraph &pcg, - MachineSpecification const &resources, - DeviceType const &device_type) { + get_random_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type) { std::vector layers = topological_ordering(pcg); std::unordered_map machine_views; for (parallel_layer_guid_t layer : layers) { @@ -22,7 +21,8 @@ std::optional if (allowed_machine_views.empty()) { return std::nullopt; } - machine_views.insert({layer, *(allowed_machine_views.begin())}); + machine_views.insert( + {layer, select_random(vector_of(allowed_machine_views))}); } return MachineMapping{machine_views}; } diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc index 1bf4f5c2b7..2c8fcea86d 100644 --- a/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc +++ b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -1 +1,15 @@ #include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using State = value_type<0>; +using SamplingFn = std::function(State)>; +using CostFn = std::function; + +template State run_mcmc(State const &starting_state, + SamplingFn const &sampler, + CostFn const &cost, + GenericMCMCConfig const &search_config); + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc index a5bec534ab..43e80630dd 100644 --- a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc +++ b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -8,54 +8,46 @@ #include "substitutions/pcg_pattern_match.h" #include "substitutions/unity_substitution_set.h" #include "utils/optional.h" +#include "utils/random_utils.h" +#include namespace FlexFlow { -SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, - RuntimeOnlyCostEstimator const &cost_estimator, - MachineSpecification const &resources, - MCMCOverMappedPCGConfig const &search_config) { +SearchResult + mcmc_over_mapped_pcg(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config) { std::vector substitutions = get_substitution_set(resources); + MachineMapping random_mapping = assert_unwrap( + get_random_mapping(pcg, resources, search_config.device_type)); + SearchResult starting_state = SearchResult{pcg, random_mapping}; - std::optional naive_mapping = - get_naive_mapping(pcg, resources, search_config.device_type); - if (naive_mapping == std::nullopt) { - throw std::runtime_error("Failed to find any solutions"); - } - - SearchResult starting_state = SearchResult{pcg, naive_mapping.value()}; - - auto generating_func = [&](SearchResult mapped_pcg, - nonnegative_int i) -> std::optional { - if (i.unwrap_nonnegative() % - search_config.substitution_interval.unwrap_nonnegative() == - 0) { - // substitutions every (substitution_interval) iterations - std::optional random_substitution = - get_random_substitution(resources); - if (random_substitution != std::nullopt) { - std::optional pattern_match = - get_random_pattern_match(random_substitution.value().pcg_pattern, - sub_pcg_from_full_pcg(mapped_pcg.pcg)); - if (pattern_match != std::nullopt) { - return apply_substitution_and_update_machine_mapping( - mapped_pcg, random_substitution.value(), pattern_match.value()); - } - } - return std::nullopt; + auto sampler = [&](SearchResult mapped_pcg) -> std::optional { + // applies substitution with substitution_frequency probability + // applies machine mapping mutation with (1 - substitution_frequency) + // probability + ASSERT(search_config.substitution_frequency >= 0 && + search_config.substitution_frequency <= 1); + if (randf() < search_config.substitution_frequency) { + Substitution random_substitution = + assert_unwrap(get_random_substitution(resources)); + std::optional maybe_pattern_match = + get_random_pattern_match(random_substitution.pcg_pattern, + sub_pcg_from_full_pcg(mapped_pcg.pcg)); + return transform(maybe_pattern_match, [&](PCGPatternMatch match) { + return apply_substitution_and_update_machine_mapping( + mapped_pcg, random_substitution, match); + }); } else { - // machine mapping mutations otherwise - std::optional new_machine_mapping = - get_random_mutation(mapped_pcg, resources, search_config.device_type); - if (new_machine_mapping == std::nullopt) { - return std::nullopt; - } - return SearchResult{mapped_pcg.pcg, new_machine_mapping.value()}; + MachineMapping new_machine_mapping = assert_unwrap(get_random_mutation( + mapped_pcg, resources, search_config.device_type)); + return SearchResult{mapped_pcg.pcg, new_machine_mapping}; } }; - auto scoring_func = [&](SearchResult mapped_pcg) -> float { + auto cost = [&](SearchResult mapped_pcg) -> float { return task_simulator_estimate_forward_pass_time(mapped_pcg.pcg, cost_estimator, mapped_pcg.machine_mapping, @@ -67,8 +59,7 @@ SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg, GenericMCMCConfig{/*temperature*/ search_config.temperature, /*num_iterations*/ search_config.num_iterations}; - SearchResult result = - minimize_score(starting_state, generating_func, scoring_func, config); + SearchResult result = run_mcmc(starting_state, sampler, cost, config); return result; } diff --git a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc index b020006f4d..b21ee4333f 100644 --- a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc +++ b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -6,9 +6,8 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("generic_mcmc_algorithm") { float starting_state = 0.1; - auto generating_func = [](float x, - nonnegative_int i) -> std::optional { - float new_x = x + (randf() - 0.5) / (i.unwrap_nonnegative() + 1); + auto sampler = [](float x) -> std::optional { + float new_x = x + (randf() - 0.5); if (new_x < 0) { return std::nullopt; } @@ -17,12 +16,11 @@ TEST_SUITE(FF_TEST_SUITE) { } return new_x; }; - auto scoring_func = [](float x) { return (x - 0.5) * (x - 0.5); }; + auto cost = [](float x) { return (x - 0.5) * (x - 0.5); }; GenericMCMCConfig config = GenericMCMCConfig{/*temperature=*/1.0, - /*num_iterations=*/50_n}; - float answer = - minimize_score(starting_state, generating_func, scoring_func, config); - float error = scoring_func(answer); + /*num_iterations=*/100_n}; + float answer = run_mcmc(starting_state, sampler, cost, config); + float error = cost(answer); CHECK(answer > 0.47); CHECK(answer < 0.53); CHECK(error >= 0); diff --git a/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc index 87cffa869f..9e2134d08b 100644 --- a/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc +++ b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -14,7 +14,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("mcmc_graph_optimize") { + TEST_CASE("mcmc_over_mapped_pcg") { ComputationGraph cg = [&] { ComputationGraphBuilder b; TensorShape input_tensor_shape = TensorShape{ @@ -61,11 +61,11 @@ TEST_SUITE(FF_TEST_SUITE) { MCMCOverMappedPCGConfig no_search = MCMCOverMappedPCGConfig{/*temperature=*/1.0, /*num_iterations=*/1_n, - /*substitution_interval=*/5_n, + /*substitution_frequency=*/0.2, /*device_type=*/DeviceType::GPU}; SearchResult base_result = - mcmc_graph_optimize(pcg, cost_estimator, full_machine_spec, no_search); + mcmc_over_mapped_pcg(pcg, cost_estimator, full_machine_spec, no_search); float base_runtime = task_simulator_estimate_forward_pass_time(base_result.pcg, cost_estimator, @@ -76,10 +76,10 @@ TEST_SUITE(FF_TEST_SUITE) { MCMCOverMappedPCGConfig search_config = MCMCOverMappedPCGConfig{/*temperature=*/1.0, /*num_iterations=*/100_n, - /*substitution_interval=*/5_n, + /*substitution_frequency=*/0.2, /*device_type=*/DeviceType::GPU}; - SearchResult result = mcmc_graph_optimize( + SearchResult result = mcmc_over_mapped_pcg( pcg, cost_estimator, full_machine_spec, search_config); float runtime = task_simulator_estimate_forward_pass_time(result.pcg, diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 377561d70c..4e4bc03cd4 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -3,6 +3,7 @@ #include "utils/exception.h" #include "utils/fmt/optional.h" +#include #include namespace FlexFlow { @@ -28,7 +29,7 @@ T const &unwrap(std::optional const &o, F const &f) { template T const &assert_unwrap(std::optional const &o) { - assert(o.has_value()); + ASSERT(o.has_value()); return o.value(); } From cd99432ff83605275d7028d0376c864e15d5017d Mon Sep 17 00:00:00 2001 From: Victor Li Date: Wed, 3 Dec 2025 19:25:33 -0800 Subject: [PATCH 5/5] refactoring apply_substitution_and_update_machine_mapping, will look into moving it to the other file, also put the evaluate_substitution_ouput outside of apply_substitution since it has a side effect, thinking of making the non-eval-sub version the default --- ...substitution_and_update_machine_mapping.cc | 189 ++++-------------- .../apply_substitution/apply_substitution.h | 8 + .../apply_substitution/apply_substitution.cc | 15 +- 3 files changed, 56 insertions(+), 156 deletions(-) diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc index 252384985b..2cb78a38b6 100644 --- a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc @@ -9,12 +9,16 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph_data.dtg.h" #include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/containers/filter.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/random_utils.h" +#include namespace FlexFlow { @@ -24,174 +28,51 @@ SearchResult apply_substitution_and_update_machine_mapping( PCGPatternMatch const &match) { SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg); - auto substitution_output_result = - evaluate_substitution_output(spcg, sub, match); - SubParallelComputationGraph substitution_output_graph = - substitution_output_result.first; - OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = - substitution_output_result.second; + std::pair + substitution_output_result = + evaluate_substitution_output(spcg, sub, match); - SubParallelComputationGraphData output_graph_data = - get_sub_pcg_data(substitution_output_graph); - SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg); + SubParallelComputationGraph post_substitution_graph = + apply_substitution_from_output_result( + substitution_output_result, spcg, sub, match); - std::unordered_set pre_nodes = - keys(pre_data.node_data); - std::unordered_set matched_nodes = - unordered_set_of(values(match.node_assignment)); - std::unordered_set post_nodes_from_original_graph = - set_minus(pre_nodes, matched_nodes); + std::unordered_map post_node_data = + get_sub_pcg_data(post_substitution_graph).node_data; + + std::unordered_set + substitution_output_parallel_layers = + get_parallel_layers(substitution_output_result.first); std::unordered_map machine_views = mapped_pcg.machine_mapping.machine_views; - std::unordered_set substituted_machine_views = + std::unordered_set matched_nodes = + unordered_set_of(values(match.node_assignment)); + + std::vector substituted_machine_views = vector_of( transform(matched_nodes, [&](parallel_layer_guid_t const &node) { return machine_views.at(node); - }); - MachineView first_substituted_machine_view = - *substituted_machine_views.begin(); - - std::unordered_map post_node_data = - [&] { - std::unordered_map - post_node_data_from_orig = restrict_keys( - pre_data.node_data, post_nodes_from_original_graph); - std::unordered_map - post_node_data_from_sub = output_graph_data.node_data; - - for (auto [layer, attrs] : post_node_data_from_sub) { - machine_views.insert_or_assign(layer, first_substituted_machine_view); - } - - return merge_disjoint_maps(post_node_data_from_orig, - post_node_data_from_sub); - }(); - - std::unordered_set post_edges = [&] { - std::unordered_set post_edges_from_orig = - filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) { - if (e.raw_edge.has()) { - return true; - } else { - DataflowEdge dfe = e.raw_edge.get(); - parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node}; - parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node}; - return !(contains(matched_nodes, src) || - contains(matched_nodes, dst)); - } - }); + })); - std::unordered_set post_edges_from_sub = - filter(output_graph_data.edges, - [&](SubParallelComputationGraphEdge const &e) { - return !e.raw_edge.has(); - }); - - bidict - output_orig_pattern_mapping = get_output_mapping_for_pcg_pattern_match( - match, sub.pcg_pattern, spcg); - bidict - output_post_outexpr_mapping = get_output_graph_expr_output_mapping( - output_expr_to_result_sub_pcg_mapping, - sub.output_graph_expr, - substitution_output_graph); - - std::unordered_set incoming_to_sub_edges; - for (auto const &[pattern_input, base_graph_tensor] : - match.input_assignment) { - OutputGraphExprInput output_expr_input = - sub.inputs_mapping.at_l(pattern_input); - input_parallel_tensor_guid_t output_graph_input = - output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( - output_expr_input); - std::unordered_set uses = get_parallel_tensor_uses( - substitution_output_graph, - open_parallel_tensor_guid_from_input(output_graph_input)); - for (parallel_tensor_use_t const &use : uses) { - SubParallelComputationGraphEdge new_edge = - subpcg_edge_from_tensor_and_use(base_graph_tensor, use); - incoming_to_sub_edges.insert(new_edge); - } - } - - std::unordered_set outgoing_from_sub_edges; - for (ParallelComputationGraphEdge const &outgoing_edge : - get_subgraph_outgoing_edges(spcg, matched_nodes)) { - parallel_tensor_guid_t original_tensor = - get_parallel_tensor(outgoing_edge); - PatternNodeOutput pattern_tensor = - output_orig_pattern_mapping.at_r(original_tensor); - OutputGraphExprNodeOutput output_graph_tensor = - sub.outputs_mapping.at_l(pattern_tensor); - parallel_tensor_guid_t new_tensor = - output_post_outexpr_mapping.at_r(output_graph_tensor); - - SubParallelComputationGraphEdge new_edge = - subpcg_edge_from_tensor_and_dst( - new_tensor, - get_dst_layer(outgoing_edge), - get_dst_layer_input_idx(outgoing_edge)); - outgoing_from_sub_edges.insert(new_edge); - } - - return set_union(std::vector{ - post_edges_from_orig, - post_edges_from_sub, - incoming_to_sub_edges, - outgoing_from_sub_edges, - }); - }(); - - std::unordered_set post_inputs = - pre_data.inputs; - - std::unordered_map - post_value_data = [&] { - std::unordered_map - post_value_data_from_orig = filter_keys( - pre_data.value_data, [&](open_parallel_tensor_guid_t const &t) { - return visit_open_parallel_tensor_guid( - t, - overload{ - [&](parallel_tensor_guid_t const &t) { - return contains(post_nodes_from_original_graph, - get_source_layer(t)); - }, - [](input_parallel_tensor_guid_t const &) { - return true; - }, - }); - }); - - std::unordered_map - post_value_data_from_sub = output_graph_data.value_data; - return merge_disjoint_maps(post_value_data_from_orig, - post_value_data_from_sub); - }(); - - SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ - post_node_data, - post_edges, - post_inputs, - post_value_data, - }; + for (parallel_layer_guid_t layer : substitution_output_parallel_layers) { + machine_views.insert_or_assign(layer, + select_random(substituted_machine_views)); + } - assert(is_subseteq_of(keys(post_node_data), keys(machine_views))); + ASSERT(is_subseteq_of(keys(post_node_data), keys(machine_views))); - for (auto it = machine_views.begin(); it != machine_views.end();) { - if (post_node_data.find(it->first) == post_node_data.end()) { - it = machine_views.erase(it); - } else { - ++it; - } - } + std::unordered_map + post_node_machine_views = + filter(machine_views, + [&](std::pair const &p) { + return post_node_data.count(p.first); + }); - assert(keys(post_node_data) == keys(machine_views)); + ASSERT(keys(post_node_data) == keys(post_node_machine_views)); return SearchResult{ - pcg_from_sub_pcg_by_dropping_inputs(sub_pcg_from_graph_data(post_data)), - MachineMapping{machine_views}}; + pcg_from_sub_pcg_by_dropping_inputs(post_substitution_graph), + MachineMapping{post_node_machine_views}}; } } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h index 92f7bb1c03..d46523ecb6 100644 --- a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h +++ b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H +#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" #include "substitutions/pcg_pattern_match.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" @@ -26,6 +27,13 @@ SubParallelComputationGraph Substitution const &substitution, PCGPatternMatch const &match); +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair + substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 61bfe15d7b..611296488e 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -20,8 +20,19 @@ SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &spcg, Substitution const &sub, PCGPatternMatch const &match) { - auto substitution_output_result = - evaluate_substitution_output(spcg, sub, match); + std::pair + substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + return apply_substitution_from_output_result( + substitution_output_result, spcg, sub, match); +} + +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair + substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { SubParallelComputationGraph substitution_output_graph = substitution_output_result.first; OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping =