diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h index c2c11fac51..1985e5c03c 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h @@ -9,8 +9,8 @@ OperatorAttributeConstraint op_type_equals_constraint(OperatorType); OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey, OperatorAttributeValue const &); -OperatorAttributeConstraint - op_attr_key_divisible_by(OperatorAttributeKey, nonnegative_int denominator); +OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey, + positive_int denominator); OperatorAttributeConstraint make_equals_constraint(OperatorAttributeExpr const &, OperatorAttributeValue const &); diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h index c1e28f8d8f..99e80eaa7f 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H #include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" -#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/positive_int/positive_int.h" namespace FlexFlow { TensorAttributePattern tensor_attribute_pattern_match_all(); TensorAttributePattern - tensor_attr_pattern_require_num_dims(nonnegative_int num_dims); + tensor_attr_pattern_require_num_dims(positive_int num_dims); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml 
b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml index ffacfafbdf..3d65a88f89 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml @@ -14,10 +14,18 @@ includes = [ "utils/hash/vector.h", "utils/fmt/vector.h", "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", ] [[values]] type = "::FlexFlow::nonnegative_int" +[[values]] +type = "::FlexFlow::positive_int" + [[values]] type = "std::vector<::FlexFlow::nonnegative_int>" + +[[values]] +type = "std::vector<::FlexFlow::positive_int>" + diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index be1a2101d0..074d41dc71 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -10,36 +10,25 @@ namespace FlexFlow { std::vector get_substitution_set(MachineComputeSpecification const &resources); -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_replicate_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, +Substitution create_replicate_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias); -Substitution create_partition_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, - Activation activation, +Substitution create_partition_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias); -Substitution create_partition_conv2d_combine(nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_partition_attention_combine(nonnegative_int num_heads, - 
nonnegative_int degree); -Substitution create_replicate_attention_reduce(nonnegative_int num_heads, - nonnegative_int degree); +Substitution create_partition_conv2d_combine(positive_int num_dims, + positive_int degree); +Substitution create_partition_attention_combine(positive_int num_heads, + positive_int degree); +Substitution create_replicate_attention_reduce(positive_int num_heads, + positive_int degree); Substitution create_partition_add_combine(ff_dim_t parallel_dim, - nonnegative_int degree); + positive_int degree); Substitution create_partition_relu_combine(ff_dim_t parallel_dim, - nonnegative_int degree); -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree); + positive_int degree); Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, ff_dim_t partition_dim, - nonnegative_int degree); + positive_int degree); Substitution create_fuse_linear_activation(Activation activation); } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index e7dc926682..8e1c06b9b5 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -1,8 +1,10 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" +#include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" @@ -54,6 +56,8 @@ LabelledOpenKwargDataflowGraphView incoming_tensor_roles = 
get_incoming_tensor_roles(n_attrs.op_attrs); + ASSERT(is_subseteq_of(keys(incoming_shapes), keys(incoming_tensor_roles))); + auto incoming_shapes_with_role = [&](IncomingTensorRole role) -> std::unordered_map { std::unordered_set slots_with_desired_role = @@ -68,6 +72,9 @@ LabelledOpenKwargDataflowGraphView weight_shapes = incoming_shapes_with_role(IncomingTensorRole::WEIGHT); + ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == + incoming_shapes); + std::unordered_map inferred_weight_shapes = get_weight_shapes(n_attrs.op_attrs, input_shapes); diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index f7fce1aca7..19349823b7 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -83,6 +83,8 @@ std::optional get_attribute(ConcatAttrs const &p, std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OUT_CHANNELS: + return OperatorAttributeValue{p.out_channels}; case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::KERNEL_H: @@ -113,6 +115,12 @@ std::optional get_attribute(ElementBinaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::DATA_TYPE: + return OperatorAttributeValue{p.compute_type}; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return OperatorAttributeValue{p.should_broadcast_lhs}; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return OperatorAttributeValue{p.should_broadcast_rhs}; default: return std::nullopt; } @@ -123,6 +131,8 @@ std::optional get_attribute(ElementUnaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case 
OperatorAttributeKey::SCALAR: + return OperatorAttributeValue{p.scalar}; default: return std::nullopt; } @@ -227,10 +237,20 @@ std::optional switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::EMBED_DIM: + return OperatorAttributeValue{p.embed_dim}; + case OperatorAttributeKey::KDIM: + return OperatorAttributeValue{p.kdim}; + case OperatorAttributeKey::VDIM: + return OperatorAttributeValue{p.vdim}; case OperatorAttributeKey::NUM_HEADS: return OperatorAttributeValue{p.num_heads}; - case OperatorAttributeKey::USE_BIAS: + case OperatorAttributeKey::BIAS: return OperatorAttributeValue{p.bias}; + case OperatorAttributeKey::ADD_BIAS_KV: + return OperatorAttributeValue{p.add_bias_kv}; + case OperatorAttributeKey::ADD_ZERO_ATTN: + return OperatorAttributeValue{p.add_zero_attn}; case OperatorAttributeKey::DROPOUT: return OperatorAttributeValue{p.dropout}; default: diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc index 29aef07e3a..a45af1e7d4 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc @@ -20,9 +20,8 @@ OperatorAttributeConstraint }; } -OperatorAttributeConstraint - op_attr_key_divisible_by(OperatorAttributeKey key, - nonnegative_int denominator) { +OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key, + positive_int denominator) { return OperatorAttributeConstraint{ ConstraintType::DIVISIBLE_BY, OperatorAttributeExpr{key}, diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index 11ef85984c..a765556c63 100644 --- 
a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -61,7 +61,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::NOOP: case OperatorType::INPUT: case OperatorType::WEIGHT: - case OperatorType::CONV2D: case OperatorType::DROPOUT: case OperatorType::LINEAR: return PCGOperatorAttrs{LinearAttrs{ @@ -75,19 +74,72 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( acc.get>( OperatorAttributeKey::REGULARIZER), }}; + case OperatorType::CONV2D: + return PCGOperatorAttrs{Conv2DAttrs{ + /*out_channels=*/acc.get( + OperatorAttributeKey::OUT_CHANNELS), + /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*padding_h=*/ + acc.get(OperatorAttributeKey::PADDING_H), + /*padding_w=*/ + acc.get(OperatorAttributeKey::PADDING_W), + /*groups=*/acc.get(OperatorAttributeKey::GROUPS), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION), + /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), + }}; + case OperatorType::RELU: + return PCGOperatorAttrs{ElementUnaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get>(OperatorAttributeKey::SCALAR), + }}; + case OperatorType::SOFTMAX: + return PCGOperatorAttrs{SoftmaxAttrs{ + acc.get(OperatorAttributeKey::AXIS), + }}; + case OperatorType::EW_ADD: + return PCGOperatorAttrs{ElementBinaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get(OperatorAttributeKey::DATA_TYPE), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_LHS), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_RHS), + }}; + case OperatorType::REPLICATE: + return PCGOperatorAttrs{ReplicateAttrs{ + /*replicate_degree=*/acc.get( + OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::REPARTITION: + return 
PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/acc.get( + OperatorAttributeKey::PARALLEL_DIM), + /*repartition_Degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::COMBINE: + return PCGOperatorAttrs{CombineAttrs{ + /*combine_dim=*/acc.get(OperatorAttributeKey::PARALLEL_DIM), + /*combine_degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::REDUCTION: + return PCGOperatorAttrs{ReductionAttrs{ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; case OperatorType::BATCHMATMUL: case OperatorType::SCALAR_MULTIPLY: case OperatorType::SCALAR_ADD: case OperatorType::SCALAR_FLOOR_DIV: case OperatorType::SCALAR_TRUE_DIV: case OperatorType::SCALAR_SUB: - case OperatorType::RELU: case OperatorType::IDENTITY: case OperatorType::SIGMOID: case OperatorType::TANH: case OperatorType::ELU: case OperatorType::FLAT: - case OperatorType::SOFTMAX: case OperatorType::BATCHNORM: case OperatorType::CONCAT: case OperatorType::SPLIT: @@ -96,7 +148,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::RESHAPE: case OperatorType::REVERSE: case OperatorType::TRANSPOSE: - case OperatorType::EW_ADD: case OperatorType::EW_MUL: case OperatorType::MATMUL: case OperatorType::MUL: @@ -143,10 +194,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::LAYERNORM: case OperatorType::GATHER: case OperatorType::BROADCAST: - case OperatorType::REPARTITION: - case OperatorType::COMBINE: - case OperatorType::REPLICATE: - case OperatorType::REDUCTION: case OperatorType::BATCH: case OperatorType::PIPELINE: case OperatorType::FUSED_PARALLEL: diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc index e1c1fe7cf6..f224c6883d 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc +++ 
b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc @@ -8,7 +8,7 @@ TensorAttributePattern tensor_attribute_pattern_match_all() { } TensorAttributePattern - tensor_attr_pattern_require_num_dims(nonnegative_int num_dims) { + tensor_attr_pattern_require_num_dims(positive_int num_dims) { return TensorAttributePattern{{ TensorAttributeConstraint{ ConstraintType::EQUAL, diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 469bc02799..f1d808c5fd 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -4,24 +4,55 @@ #include "substitutions/output_graph/output_operator_attrs_assignment.h" #include "substitutions/substitution_builder.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "utils/containers/get_only.h" #include "utils/containers/require_only_key.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/positive_int/positive_range.h" namespace FlexFlow { std::vector get_substitution_set(MachineComputeSpecification const &resources) { std::vector substitutions; - for (nonnegative_int num_dims : - nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { - for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources); - degree *= 2_n) { + + positive_int max_tensor_dim = positive_int{MAX_TENSOR_DIM}; + + for (positive_int dim : positive_range(1_p, max_tensor_dim + 1_p)) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { + substitutions.push_back( + create_replicate_linear_combine(dim, degree, true)); substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, true)); + create_replicate_linear_combine(dim, degree, false)); substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, 
false)); + create_partition_linear_combine(dim, degree, true)); + substitutions.push_back( + create_partition_linear_combine(dim, degree, false)); + substitutions.push_back(create_partition_relu_combine( + ff_dim_t{dim.nonnegative_int_from_positive_int()}, degree)); + substitutions.push_back(create_partition_add_combine( + ff_dim_t{dim.nonnegative_int_from_positive_int()}, degree)); + substitutions.push_back(create_partition_attention_combine(dim, degree)); + substitutions.push_back(create_replicate_attention_reduce(dim, degree)); + } + } + + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { + substitutions.push_back(create_partition_conv2d_combine(4_p, degree)); + } + + for (positive_int partition_dim : positive_range(1_p, max_tensor_dim + 1_p)) { + for (positive_int softmax_dim : positive_range(1_p, max_tensor_dim + 1_p)) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { + if (partition_dim != softmax_dim) { + substitutions.push_back(create_partition_softmax_combine( + ff_dim_t{softmax_dim.nonnegative_int_from_positive_int()}, + ff_dim_t{partition_dim.nonnegative_int_from_positive_int()}, + degree)); + } + } } } substitutions.push_back(create_fuse_linear_activation(Activation::RELU)); @@ -31,20 +62,116 @@ std::vector return substitutions; } -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +static PatternValue insert_single_output_pattern( + SubstitutionBuilder &b, + OperatorAttributePattern const &attribute_pattern, + std::unordered_map const &inputs, + TensorAttributePattern const &output_pattern, + std::string const &name) { + return require_only_key(b.add_pattern_node(attribute_pattern, + inputs, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + output_pattern, + }, + }, + name), + TensorSlotName::OUTPUT); +} + +static OutputGraphExprValue insert_single_output_op( + 
SubstitutionBuilder &b, + OutputOperatorAttrsAssignment const &expr, + std::unordered_map const &inputs) { + return require_only_key( + b.add_output_graph_node(expr, inputs, {TensorSlotName::OUTPUT}), + TensorSlotName::OUTPUT); +} + +static OutputGraphExprValue + insert_replicate_or_reduce(OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + + ASSERT(op_type == OperatorType::REPLICATE || + op_type == OperatorType::REDUCTION); + + OutputOperatorAttrsAssignment replicate_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(op_type), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + return insert_single_output_op( + b, replicate_expr, {{TensorSlotName::INPUT, input}}); +} + +static OutputGraphExprValue + insert_replicate(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REPLICATE, b, degree, input); +} + +static OutputGraphExprValue insert_reduce(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REDUCTION, b, degree, input); +} + +static OutputGraphExprValue + insert_partition_or_combine(OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + ASSERT(op_type == OperatorType::REPARTITION || + op_type == OperatorType::COMBINE); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(op_type), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{dim}), + }}; + + OutputGraphExprValue o_partition_output = insert_single_output_op( + b, partition_input_expr, {{TensorSlotName::INPUT, input}}); + + 
return o_partition_output; +} + +static OutputGraphExprValue + insert_partition(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + return insert_partition_or_combine( + OperatorType::REPARTITION, b, degree, dim, input); } -Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +static OutputGraphExprValue insert_combine(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + return insert_partition_or_combine( + OperatorType::COMBINE, b, degree, dim, input); } -Substitution create_replicate_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, +Substitution create_replicate_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias) { SubstitutionBuilder b; @@ -70,72 +197,22 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, op_type_equals_constraint(OperatorType::LINEAR), op_attr_key_equals(OperatorAttributeKey::BIAS, OperatorAttributeValue{use_bias}), - op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, - nonnegative_int{degree}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_linear_output = require_only_key( - b.add_pattern_node(linear_pattern, - p_inputs, - { - { - TensorSlotName::OUTPUT, - tensor_attr_pattern_require_num_dims( - nonnegative_int{num_dims}), - }, - }, - "linear"), - TensorSlotName::OUTPUT); + std::string linear_name = "linear"; + PatternValue p_linear_output = insert_single_output_pattern( + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); - OutputOperatorAttrsAssignment replicate_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), 
- }}; OutputGraphExprValue o_replicate_input_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/replicate_input_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_input, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); - - OutputOperatorAttrsAssignment partition_weights_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{1_n}}), - }}; + insert_replicate(b, degree, o_input); + OutputGraphExprValue o_partition_weights_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/partition_weights_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_weight, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + insert_partition(b, degree, ff_dim_t{1_n}, o_weight); std::unordered_map o_linear_inputs = { { @@ -149,31 +226,9 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, }; if (use_bias) { - OutputOperatorAttrsAssignment partition_bias_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{1_n}}), - }}; OutputGraphExprValue o_partition_bias_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/partition_bias_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_bias.value(), - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + insert_partition(b, degree, ff_dim_t{1_n}, o_bias.value()); + o_linear_inputs.insert({ TensorSlotName::BIAS, o_partition_bias_output, @@ -181,97 +236,513 @@ 
Substitution create_replicate_linear_combine(nonnegative_int num_dims, } OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("linear"), + b.pattern_node_named(linear_name), {}, }; OutputGraphExprValue o_linear_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/linear_expr, - /*inputs=*/o_linear_inputs, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); - - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, + insert_single_output_op(b, linear_expr, o_linear_inputs); + + ff_dim_t combine_output_dim = ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, + }; + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, combine_output_dim, o_linear_output); + + b.equate_outputs(p_linear_output, o_combine_output); + + return b.get_substitution(); +} + +Substitution create_partition_linear_combine(positive_int num_dims, + positive_int degree, + bool use_bias) { + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + std::unordered_map p_inputs = { { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant( - OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{ - nonnegative_int{num_dims.unwrap_nonnegative() - 1}, - }}), + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, }, }; + std::optional o_bias = std::nullopt; + if (use_bias) { + std::pair bias = + b.add_input(tensor_attribute_pattern_match_all()); + p_inputs.insert({ + TensorSlotName::BIAS, + bias.first, + }); + o_bias = bias.second; + } + + OperatorAttributePattern linear_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + 
op_attr_key_equals(OperatorAttributeKey::BIAS, + OperatorAttributeValue{use_bias}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + std::string linear_name = "linear"; + PatternValue p_linear_output = insert_single_output_pattern( + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); + + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_input); + + OutputGraphExprValue o_replicate_weights_output = + insert_replicate(b, degree, o_weight); + + std::unordered_map o_linear_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + { + TensorSlotName::WEIGHT, + o_replicate_weights_output, + }, + }; + + if (use_bias) { + OutputGraphExprValue o_replicate_bias_output = + insert_replicate(b, degree, o_bias.value()); + + o_linear_inputs.insert({ + TensorSlotName::BIAS, + o_replicate_bias_output, + }); + } + + OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(linear_name), + {}, + }; + OutputGraphExprValue o_linear_output = + insert_single_output_op(b, linear_expr, o_linear_inputs); + + ff_dim_t combine_output_dim = ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, + }; OutputGraphExprValue o_combine_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/combine_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_linear_output, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + insert_combine(b, degree, combine_output_dim, o_linear_output); b.equate_outputs(p_linear_output, o_combine_output); return b.get_substitution(); } -Substitution create_partition_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, - Activation activation, - bool use_bias) { - NOT_IMPLEMENTED(); +Substitution create_partition_conv2d_combine(positive_int num_dims, + positive_int degree) { + ASSERT(num_dims 
== 4_p); + + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + + std::unordered_map p_inputs = { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::FILTER, + p_weight, + }, + }; + + OperatorAttributePattern conv2d_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::CONV2D), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + std::string conv2d_name = "conv2d"; + PatternValue p_conv2d_output = insert_single_output_pattern( + b, + conv2d_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + conv2d_name); + + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_input); + + OutputGraphExprValue o_replicate_weights_output = + insert_replicate(b, degree, o_weight); + + std::unordered_map o_conv2d_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + {TensorSlotName::FILTER, o_replicate_weights_output}, + }; + + OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(conv2d_name), + {}, + }; + OutputGraphExprValue o_conv2d_output = + insert_single_output_op(b, conv2d_expr, o_conv2d_inputs); + + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, ff_dim_t{0_n}, o_conv2d_output); + + b.equate_outputs(p_conv2d_output, o_combine_output); + + return b.get_substitution(); } -Substitution create_partition_conv2d_combine(nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_partition_attention_combine(positive_int num_heads, + positive_int degree) { + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto 
[p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + b.add_input(tensor_attribute_pattern_match_all()); + std::unordered_map p_inputs = { + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, + }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, + }; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + std::string attention_name = "attention"; + PatternValue p_attention_output = insert_single_output_pattern( + b, + attention_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + attention_name); + + OutputGraphExprValue o_partition_query_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_query_input); + + OutputGraphExprValue o_partition_key_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_key_input); + + OutputGraphExprValue o_partition_value_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_value_input); + + OutputGraphExprValue o_replicate_weight_output = + insert_replicate(b, degree, o_weights); + + std::unordered_map o_attention_inputs = + { + { + TensorSlotName::QUERY, + o_partition_query_input_output, + }, + { + TensorSlotName::KEY, + o_partition_key_input_output, + }, + { + TensorSlotName::VALUE, + o_partition_value_input_output, + }, + { + TensorSlotName::WEIGHT, + o_replicate_weight_output, + }, + }; + + OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(attention_name), + {}, + }; + OutputGraphExprValue o_attention_output = + insert_single_output_op(b, attention_expr, o_attention_inputs); + + OutputGraphExprValue o_combine_output = + 
insert_combine(b, degree, ff_dim_t{0_n}, o_attention_output); + + b.equate_outputs(p_attention_output, o_combine_output); + + return b.get_substitution(); } -Substitution create_partition_attention_combine(nonnegative_int num_heads, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_replicate_attention_reduce(positive_int num_heads, + positive_int degree) { + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + b.add_input(tensor_attribute_pattern_match_all()); + + std::unordered_map p_inputs = { + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, + }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, + }; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + std::string attention_name = "attention"; + PatternValue p_attention_output = insert_single_output_pattern( + b, + attention_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + attention_name); + + OutputGraphExprValue o_replicate_query_input_output = + insert_replicate(b, degree, o_query_input); + + OutputGraphExprValue o_replicate_key_input_output = + insert_replicate(b, degree, o_key_input); + + OutputGraphExprValue o_replicate_value_input_output = + insert_replicate(b, degree, o_value_input); + + OutputGraphExprValue o_partition_weight_output = + insert_partition(b, degree, ff_dim_t{1_n}, o_weights); + + std::unordered_map o_attention_inputs = + { + { 
+ TensorSlotName::QUERY, + o_replicate_query_input_output, + }, + { + TensorSlotName::KEY, + o_replicate_key_input_output, + }, + { + TensorSlotName::VALUE, + o_replicate_value_input_output, + }, + { + TensorSlotName::WEIGHT, + o_partition_weight_output, + }, + }; + + OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(attention_name), + {}, + }; + OutputGraphExprValue o_attention_output = + insert_single_output_op(b, attention_expr, o_attention_inputs); + + OutputGraphExprValue o_reduce_output = + insert_reduce(b, degree, o_attention_output); + + b.equate_outputs(p_attention_output, o_reduce_output); + + return b.get_substitution(); } -Substitution create_replicate_attention_reduce(nonnegative_int num_heads, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, + ff_dim_t partition_dim, + positive_int degree) { + ASSERT(partition_dim != softmax_dim); + + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + std::unordered_map p_inputs = { + { + TensorSlotName::INPUT, + p_input, + }, + }; + + OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::SOFTMAX), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::SOFTMAX_DIM, + positive_int{softmax_dim.value}), + }}; + + std::string softmax_name = "softmax"; + PatternValue p_softmax_output = insert_single_output_pattern( + b, + softmax_pattern, + p_inputs, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + softmax_name); + + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, partition_dim, o_input); + + std::unordered_map o_softmax_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + }; + + OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{ + 
b.pattern_node_named(softmax_name), + {}, + }; + OutputGraphExprValue o_softmax_output = + insert_single_output_op(b, softmax_expr, o_softmax_inputs); + + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, partition_dim, o_softmax_output); + + b.equate_outputs(p_softmax_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_add_combine(ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); + positive_int degree) { + SubstitutionBuilder b; + + auto [p_input1, o_input1] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_input2, o_input2] = b.add_input(tensor_attribute_pattern_match_all()); + + std::unordered_map p_inputs = { + { + TensorSlotName::LHS_INPUT, + p_input1, + }, + { + TensorSlotName::RHS_INPUT, + p_input2, + }, + }; + + OperatorAttributePattern add_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::EW_ADD), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + std::string add_name = "add"; + PatternValue p_add_output = insert_single_output_pattern( + b, + add_pattern, + p_inputs, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + add_name); + + OutputGraphExprValue o_partition_input1_output = + insert_partition(b, degree, parallel_dim, o_input1); + OutputGraphExprValue o_partition_input2_output = + insert_partition(b, degree, parallel_dim, o_input2); + + std::unordered_map o_add_inputs = { + { + TensorSlotName::LHS_INPUT, + o_partition_input1_output, + }, + { + TensorSlotName::RHS_INPUT, + o_partition_input2_output, + }, + }; + + OutputOperatorAttrsAssignment add_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(add_name), + {}, + }; + OutputGraphExprValue o_add_output = + insert_single_output_op(b, add_expr, o_add_inputs); + + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, parallel_dim, o_add_output); + + b.equate_outputs(p_add_output, o_combine_output); + + return 
b.get_substitution(); } Substitution create_partition_relu_combine(ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + positive_int degree) { + SubstitutionBuilder b; -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); -Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, - ff_dim_t partition_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + std::string relu_name = "relu"; + PatternValue p_relu_output = insert_single_output_pattern( + b, + relu_pattern, + {{TensorSlotName::INPUT, p_input}}, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); + + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, parallel_dim, o_input); + + OutputOperatorAttrsAssignment relu_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named(relu_name), + {}, + }; + OutputGraphExprValue o_relu_output = insert_single_output_op( + b, relu_expr, {{TensorSlotName::INPUT, o_partition_input_output}}); + + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, parallel_dim, o_relu_output); + + b.equate_outputs(p_relu_output, o_combine_output); + + return b.get_substitution(); } Substitution create_fuse_linear_activation(Activation activation) { @@ -288,78 +759,64 @@ Substitution create_fuse_linear_activation(Activation activation) { OperatorAttributeKey::ACTIVATION, OperatorAttributeValue{std::optional{std::nullopt}}), }}; - PatternValue p_mm_output = - require_only_key(b.add_pattern_node( - /*node_expr=*/mm_pattern, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - p_input, - }, - { - 
TensorSlotName::WEIGHT, - p_weight, - }, - }, - /*output_patterns=*/ - { - { - TensorSlotName::OUTPUT, - tensor_attribute_pattern_match_all(), - }, - }, - /*name=*/"mm"), - TensorSlotName::OUTPUT); + + std::string mm_name = "mm"; + PatternValue p_mm_output = insert_single_output_pattern( + b, + mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + mm_name); OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::RELU), }}; - PatternValue p_relu_output = - require_only_key(b.add_pattern_node( - /*node_expr=*/relu_pattern, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - p_mm_output, - }, - }, - /*output_patterns=*/ - { - { - TensorSlotName::OUTPUT, - tensor_attribute_pattern_match_all(), - }, - }, - /*name=*/"relu"), - TensorSlotName::OUTPUT); + + std::string relu_name = "relu"; + PatternValue p_relu_output = insert_single_output_pattern( + b, + relu_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); OutputOperatorAttrsAssignment fused_node_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("mm"), + b.pattern_node_named(mm_name), { set_attr_to_constant(OperatorAttributeKey::ACTIVATION, OperatorAttributeValue{activation}), }}; + OutputGraphExprValue o_fused_node_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/fused_node_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_input, - }, - { - TensorSlotName::WEIGHT, - o_weight, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + insert_single_output_op(b, + fused_node_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }); b.equate_outputs(p_relu_output, o_fused_node_output); diff --git 
a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index ea8c8529ba..df7f28538e 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -1,8 +1,203 @@ #include "substitutions/unity_substitution_set.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/operator_type.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution_builder.h" +#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; +template +static ParallelLayerAttrs make_layer_attrs( + T const &op_attrs, + std::optional const &maybe_name = std::nullopt) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/maybe_name, + }; +}; + +parallel_tensor_guid_t + get_single_output(ParallelLayerAddedResult const &added) { + return require_only_key(added.outputs, TensorSlotName::OUTPUT); +} + +parallel_tensor_guid_t add_single_output_layer( + ParallelComputationGraph &pcg, + ParallelLayerAttrs const &layer_attrs, + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &outputs = 
std::nullopt) { + + return get_single_output( + add_parallel_layer(pcg, layer_attrs, inputs, weights, outputs)); +} + +parallel_tensor_guid_t add_input_layer(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + + return get_single_output(pcg_add_input_layer(pcg, tensor_shape)); +} + +parallel_tensor_guid_t add_weight_layer(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/tensor_shape, + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + return add_single_output_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); +} + +parallel_tensor_guid_t + add_replicate_layer(ParallelComputationGraph &pcg, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + ReplicateAttrs replicate_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(replicate_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t + add_reduction_layer(ParallelComputationGraph &pcg, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + ReductionAttrs reduction_attrs = ReductionAttrs{ + /*reduction_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(reduction_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t + add_partition_layer(ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + RepartitionAttrs partition_attrs = RepartitionAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t + add_combine_layer(ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + CombineAttrs partition_attrs = CombineAttrs{ + /*combine_dim=*/dim, + 
/*combine_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t add_linear_layer( + ParallelComputationGraph &pcg, + LinearAttrs const &linear_attrs, + parallel_tensor_guid_t const &t_input, + parallel_tensor_guid_t const &t_weight, + std::optional const &t_bias = std::nullopt, + std::optional const &name = std::nullopt) { + + ASSERT(t_bias.has_value() == linear_attrs.use_bias); + + std::unordered_map weights = { + {TensorSlotName::WEIGHT, t_weight}, + }; + + if (t_bias.has_value()) { + weights.insert({TensorSlotName::BIAS, t_bias.value()}); + } + + return add_single_output_layer(pcg, + make_layer_attrs(linear_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + weights); +} + +parallel_tensor_guid_t + add_attention_layer(ParallelComputationGraph &pcg, + MultiHeadAttentionAttrs const &attn_attrs, + parallel_tensor_guid_t const &t_query, + parallel_tensor_guid_t const &t_key, + parallel_tensor_guid_t const &t_value, + parallel_tensor_guid_t const &t_weights, + std::optional const &name = std::nullopt) { + + return add_single_output_layer(pcg, + make_layer_attrs(attn_attrs, name), + { + {TensorSlotName::QUERY, t_query}, + {TensorSlotName::KEY, t_key}, + {TensorSlotName::VALUE, t_value}, + }, + {{TensorSlotName::WEIGHT, t_weights}}); +} + +parallel_tensor_guid_t add_conv2d_layer( + ParallelComputationGraph &pcg, + Conv2DAttrs const &conv2d_attrs, + parallel_tensor_guid_t const &t_input, + parallel_tensor_guid_t const &t_filter, + std::optional const &bias = std::nullopt, + std::optional const &name = std::nullopt) { + + ASSERT(bias.has_value() == conv2d_attrs.use_bias); + + std::unordered_map weights = { + {TensorSlotName::FILTER, t_filter}, + }; + + if (bias.has_value()) { + weights.insert({TensorSlotName::BIAS, bias.value()}); + } + + return add_single_output_layer(pcg, + make_layer_attrs(conv2d_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + 
weights); +} + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { MachineComputeSpecification machine_spec = MachineComputeSpecification{ @@ -13,6 +208,1151 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector result = get_substitution_set(machine_spec); - CHECK(result.size() == 36); + CHECK(result.size() == 248); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = false") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + + RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); + + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + linear_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t 
match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_replicated_input = + add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); + + parallel_tensor_guid_t t_partitioned_projection_weight = + add_partition_layer(pcg, + ff_dim_t{1_n}, + degree, + add_weight_layer(pcg, projection_weight_shape)); + + parallel_tensor_guid_t t_replicated_linear = + add_linear_layer(pcg, + linear_attrs, + t_replicated_input, + t_partitioned_projection_weight); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_replicated_input); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = true") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + TensorShape projection_weight_shape = + 
throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + + TensorShape bias_shape = + throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); + + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); + + parallel_tensor_guid_t t_bias = add_weight_layer(pcg, bias_shape); + + parallel_tensor_guid_t t_linear = add_linear_layer( + pcg, linear_attrs, t_input, t_projection_weight, t_bias); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::OUTPUT); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{KwargDataflowGraphInput{4}}, + match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_replicated_input = + add_replicate_layer(pcg, degree, add_input_layer(pcg, 
input_shape)); + + parallel_tensor_guid_t t_partitioned_projection_weight = + add_partition_layer(pcg, + ff_dim_t{1_n}, + degree, + add_weight_layer(pcg, projection_weight_shape)); + + parallel_tensor_guid_t t_partitioned_bias = add_partition_layer( + pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, bias_shape)); + + parallel_tensor_guid_t t_replicated_linear = + add_linear_layer(pcg, + linear_attrs, + t_replicated_linear, + t_partitioned_projection_weight, + t_partitioned_bias); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_replicated_linear); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = false") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); + + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + linear_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + 
parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); + + parallel_tensor_guid_t t_replicated_projection_weight = + add_replicate_layer( + pcg, degree, add_weight_layer(pcg, projection_weight_shape)); + + parallel_tensor_guid_t t_partitioned_linear = + add_linear_layer(pcg, + linear_attrs, + t_partitioned_input, + t_replicated_projection_weight); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_partitioned_input); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = true") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + 
/*out_channels=*/12_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + + TensorShape bias_shape = + throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); + + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); + parallel_tensor_guid_t t_bias = add_weight_layer(pcg, bias_shape); + + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + t_bias, + linear_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::BIAS); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{KwargDataflowGraphInput{4}}, + match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + 
SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); + + parallel_tensor_guid_t t_replicated_projection_weight = + add_replicate_layer( + pcg, degree, add_weight_layer(pcg, projection_weight_shape)); + + parallel_tensor_guid_t t_replicated_bias = + add_replicate_layer(pcg, degree, add_weight_layer(pcg, bias_shape)); + + parallel_tensor_guid_t t_partitioned_linear = + add_linear_layer(pcg, + linear_attrs, + t_partitioned_input, + t_replicated_projection_weight, + t_replicated_bias); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_partitioned_linear); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_conv2d_combine") { + positive_int outChannels = 6_p; + positive_int kernelH = 5_p; + positive_int kernelW = 4_p; + positive_int strideH = 3_p; + positive_int strideW = 2_p; + nonnegative_int paddingH = 1_n; + nonnegative_int paddingW = 0_n; + positive_int num_dims = 4_p; + positive_int degree = 2_p; + std::string conv2d_match = "conv2d_match"; + + Substitution sub = create_partition_conv2d_combine(num_dims, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 3_p, + 12_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + Conv2DAttrs conv2d_attrs = Conv2DAttrs{ + /*outChannels=*/outChannels, + /*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW, + /*groups=*/1_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + TensorShape 
casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + TensorShape projection_weight_shape = + get_weight_shapes(conv2d_attrs, casted_input_shape) + .at(TensorSlotName::FILTER); + + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); + + parallel_tensor_guid_t t_conv = add_conv2d_layer(pcg, + conv2d_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + conv2d_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, conv2d_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::FILTER); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_partitioned_input = + add_partition_layer(pcg, ff_dim_t{0_n}, degree, t_input); + + TensorShape casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + TensorShape weight_shape = + get_weight_shapes(conv2d_attrs, casted_input_shape) + .at(TensorSlotName::FILTER); + + parallel_tensor_guid_t t_replicated_weight = + add_replicate_layer(pcg, degree, add_weight_layer(pcg, weight_shape)); + + parallel_tensor_guid_t t_partitioned_conv2d = add_conv2d_layer( + pcg, 
conv2d_attrs, t_partitioned_input, t_replicated_weight); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, ff_dim_t{0_n}, degree, t_partitioned_conv2d); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_attention_combine") { + positive_int embed_dim = 8_p; + positive_int num_heads = 6_p; + positive_int degree = 2_p; + std::string attention_match = "attention_match"; + + Substitution sub = create_partition_attention_combine(num_heads, degree); + + TensorShape query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + TensorShape weights_shape = throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)); + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_query = add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_key = add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_value = add_input_layer(pcg, value_shape); + + parallel_tensor_guid_t t_weights = add_weight_layer(pcg, weights_shape); + + parallel_tensor_guid_t t_attention = add_attention_layer(pcg, + attention_attrs, + t_query, + t_key, + t_value, + t_weights, + attention_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, 
match_layer).at(TensorSlotName::QUERY); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::KEY); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::VALUE); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + /* Hand-built match: pattern node 0 is bound to the attention layer, and pattern inputs 0, 2, 4 and 6 are bound to its QUERY, KEY, VALUE and WEIGHT inputs. */ + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_key, + }, + { + PatternInput{KwargDataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{KwargDataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + /* Expected graph, built by hand for the isomorphism check below. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_query = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, value_shape)); + + parallel_tensor_guid_t t_weight = add_replicate_layer( + pcg, degree, add_weight_layer(pcg, weights_shape)); + + parallel_tensor_guid_t t_partitioned_attention = add_attention_layer( + pcg, attention_attrs, t_query, t_key, t_value, t_weight); + + parallel_tensor_guid_t t_combine = add_combine_layer( + pcg, ff_dim_t{0_n}, degree, t_partitioned_attention); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + /* Applies create_replicate_attention_reduce(num_heads, degree) to a graph holding one attention layer and checks that the result is isomorphic to the expected graph: query, key and value inputs replicated, weights partitioned along ff_dim 1, attention, then a reduction. */ + TEST_CASE("create_replicate_attention_reduce") { + positive_int embed_dim = 8_p; + positive_int num_heads = 6_p; +
positive_int degree = 2_p; + std::string attention_match = "attention_match"; + + Substitution sub = create_replicate_attention_reduce(num_heads, degree); + + TensorShape query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + TensorShape weight_shape = throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)); + /* Input graph: three input layers and one weight layer feeding a single attention layer; attention_match is passed to it so the layer can be fetched by name when the match is built. */ + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_query = add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_key = add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_value = add_input_layer(pcg, value_shape); + + parallel_tensor_guid_t t_weight = add_weight_layer(pcg, weight_shape); + + parallel_tensor_guid_t attention_added = + add_attention_layer(pcg, + attention_attrs, + t_query, + t_key, + t_value, + t_weight, + attention_match); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::QUERY); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::KEY); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::VALUE); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); + +
return PCGPatternMatch{ /* Hand-built match: pattern node 0 is bound to the attention layer, and pattern inputs 0, 2, 4 and 6 are bound to its QUERY, KEY, VALUE and WEIGHT inputs. */ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + match_layer_key, + }, + { + PatternInput{KwargDataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{KwargDataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + /* Expected graph, built by hand for the isomorphism check below. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_query = + add_replicate_layer(pcg, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = + add_replicate_layer(pcg, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = + add_replicate_layer(pcg, degree, add_input_layer(pcg, value_shape)); + + parallel_tensor_guid_t t_weight = add_partition_layer( + pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, weight_shape)); + + parallel_tensor_guid_t t_replicated_attention = add_attention_layer( + pcg, attention_attrs, t_query, t_key, t_value, t_weight); + + parallel_tensor_guid_t t_reduction = + add_reduction_layer(pcg, degree, t_replicated_attention); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + /* Applies create_partition_softmax_combine(softmax_dim, partition_dim, degree) to a graph holding one softmax layer and checks that the result is isomorphic to the expected graph: input partitioned along partition_dim, softmax, then a combine along partition_dim. */ + TEST_CASE("create_partition_softmax_combine") { + positive_int degree = 2_p; + ff_dim_t softmax_dim = ff_dim_t{1_n}; + ff_dim_t partition_dim = ff_dim_t{0_n}; + std::string softmax_match = "softmax_match"; + + Substitution sub = + create_partition_softmax_combine(softmax_dim, partition_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + SoftmaxAttrs softmax_attrs = SoftmaxAttrs{ + /*softmax_dim=*/softmax_dim, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg =
empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_softmax = add_single_output_layer( + pcg, + make_layer_attrs(softmax_attrs, softmax_match), + {{TensorSlotName::INPUT, t_input}}, + {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, softmax_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + /* Hand-built match: pattern node 0 is bound to the softmax layer and pattern input 0 to its INPUT. */ + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + /* Expected graph, built by hand for the isomorphism check below. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, partition_dim, degree, add_input_layer(pcg, input_shape)); + + parallel_tensor_guid_t t_partitioned_softmax = add_single_output_layer( + pcg, + make_layer_attrs(softmax_attrs), + {{TensorSlotName::INPUT, t_partitioned_input}}, + {}); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, partition_dim, degree, t_partitioned_softmax); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + /* Applies create_partition_add_combine(parallel_dim, degree) to a graph holding one EW_ADD layer and checks that the result is isomorphic to the expected graph: both inputs partitioned along parallel_dim, the add, then a combine along parallel_dim. */ + TEST_CASE("create_partition_add_combine") { + positive_int degree = 2_p; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string add_match = "add_match"; + + Substitution sub = create_partition_add_combine(parallel_dim, degree); + + TensorShape lhs_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 15_p, + }, + }, + DataType::FLOAT, + }; + + TensorShape rhs_shape = lhs_shape; + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, +
}; + /* Input graph: two input layers feeding one add layer; add_match is passed to make_layer_attrs so the layer can be fetched by name when the match is built. */ + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_lhs = add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_rhs = add_input_layer(pcg, rhs_shape); + + parallel_tensor_guid_t t_add = + add_single_output_layer(pcg, + make_layer_attrs(add_attrs, add_match), + { + {TensorSlotName::LHS_INPUT, t_lhs}, + {TensorSlotName::RHS_INPUT, t_rhs}, + }, + {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, add_match); + open_parallel_tensor_guid_t add_match_layer_lhs = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::LHS_INPUT); + open_parallel_tensor_guid_t add_match_layer_rhs = + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::RHS_INPUT); + /* Hand-built match: pattern node 0 is bound to the add layer, pattern inputs 0 and 2 to its LHS_INPUT and RHS_INPUT. */ + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + add_match_layer_lhs, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + add_match_layer_rhs, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + /* Expected graph, built by hand for the isomorphism check that follows. NOTE(review): the expected add layer below is built with make_layer_attrs(add_attrs, add_match), whereas the softmax and relu tests build their expected layer without a name - confirm this is intentional. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_lhs = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, lhs_shape)); + parallel_tensor_guid_t t_rhs = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, rhs_shape)); + + parallel_tensor_guid_t t_partitioned_add = + add_single_output_layer(pcg, + make_layer_attrs(add_attrs, add_match), + { + {TensorSlotName::LHS_INPUT, t_lhs}, + {TensorSlotName::RHS_INPUT, t_rhs}, + }, + {}); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, parallel_dim, degree, t_partitioned_add); + + return sub_pcg_from_full_pcg(pcg); + }(); + +
CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + /* Applies create_partition_relu_combine(parallel_dim, degree) to a graph holding one RELU layer and checks that the result is isomorphic to the expected graph: input partitioned along parallel_dim, relu, then a combine along parallel_dim. */ + TEST_CASE("create_partition_relu_combine") { + positive_int degree = 2_p; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string relu_match = "relu_match"; + + Substitution sub = create_partition_relu_combine(parallel_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + ElementUnaryAttrs relu_attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + /* Input graph: one input layer feeding one relu layer; relu_match is passed to make_layer_attrs so the layer can be fetched by name when the match is built. */ + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_relu = + add_single_output_layer(pcg, + make_layer_attrs(relu_attrs, relu_match), + {{TensorSlotName::INPUT, t_input}}, + {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, relu_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); + /* Hand-built match: pattern node 0 is bound to the relu layer and pattern input 0 to its INPUT. */ + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{KwargDataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + /* Expected graph, built by hand for the isomorphism check below. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + parallel_tensor_guid_t t_input = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, input_shape)); + + parallel_tensor_guid_t t_relu = + add_single_output_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); + + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, parallel_dim, degree, t_relu); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result,
correct)); + } + /* Applies create_fuse_linear_activation(Activation::SIGMOID) to a dense layer (no activation, no bias) followed by a second layer, using a hand-built match, and checks that the result is isomorphic to a single dense layer whose activation is SIGMOID. NOTE(review): the second matched layer is created with b.relu and named relu_match, while the substitution and the expected graph use SIGMOID - confirm this mismatch is intentional. */ + TEST_CASE("create_fuse_linear_activation") { + Substitution sub = create_fuse_linear_activation(Activation::SIGMOID); + + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + /* Input graph, built with ParallelComputationGraphBuilder: input tensor, dense named mm_match, then relu named relu_match. */ + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::INPUT); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::WEIGHT); + /* Hand-built match: pattern nodes 0 and 1 are bound to the dense and relu layers; pattern inputs 0 and 2 are bound to the dense layer's INPUT and WEIGHT. */ + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, mm_match_layer}, + {PatternNode{Node{1}}, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{KwargDataflowGraphInput{0}}, + mm_match_layer_input_activations, + }, + { + PatternInput{KwargDataflowGraphInput{2}}, + mm_match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = apply_substitution(pcg, sub, match); + /* Expected graph: the same input tensor feeding one unnamed dense layer with activation SIGMOID. */ + SubParallelComputationGraph correct = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_p, + /*activation=*/Activation::SIGMOID, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, +
/*bias_initializer=*/std::nullopt, + /*name=*/std::nullopt); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); } } diff --git a/lib/utils/include/utils/positive_int/positive_range.h b/lib/utils/include/utils/positive_int/positive_range.h new file mode 100644 index 0000000000..f064f766c8 --- /dev/null +++ b/lib/utils/include/utils/positive_int/positive_range.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_POSITIVE_INT_POSITIVE_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_POSITIVE_INT_POSITIVE_RANGE_H + +#include "utils/positive_int/positive_int.h" + +namespace FlexFlow { + /* positive_int counterpart of range(): yields the values of range(start, end, step) over the underlying ints, each wrapped in positive_int; step defaults to 1. Defined in positive_range.cc. NOTE(review): std::vector is used here without a direct include of the vector header; presumably it arrives through positive_int.h - confirm. */ +std::vector + positive_range(positive_int start, positive_int end, int step = 1); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/positive_int/positive_range.cc b/lib/utils/src/utils/positive_int/positive_range.cc new file mode 100644 index 0000000000..bb52f0b4d9 --- /dev/null +++ b/lib/utils/src/utils/positive_int/positive_range.cc @@ -0,0 +1,14 @@ +#include "utils/positive_int/positive_range.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + /* Unwraps start and end with int_from_positive_int(), calls range(start, end, step) on the raw ints, and wraps every produced int back into positive_int via transform. */ +std::vector + positive_range(positive_int start, positive_int end, int step) { + return transform( + range(start.int_from_positive_int(), end.int_from_positive_int(), step), + [](int x) { return positive_int{x}; }); +} + +} // namespace FlexFlow