Skip to content

Commit 47cc58a

Browse files
author
Victor Li
committed
nonnegative_int additions, code cleanup, etc.
1 parent 30f2b6e commit 47cc58a

File tree

31 files changed

+293
-503
lines changed

31 files changed

+293
-503
lines changed

lib/compiler/src/compiler/allowed_machine_views.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ static std::unordered_set<MachineView>
6060
transform(range(1, max_stride_upper_bound + 1),
6161
[](int stride) { return stride_t{stride}; });
6262
std::unordered_multiset<std::vector<stride_t>> raw_stride_vectors =
63-
cartesian_product(replicate(tensor_dims.size(), single_stride_range));
63+
cartesian_product(replicate(nonnegative_int{tensor_dims.size()},
64+
single_stride_range));
6465
std::unordered_multiset<MultiDimensionalStride> strides =
6566
transform(raw_stride_vectors, [](auto const &stride_vec) {
6667
return MultiDimensionalStride{stride_vec};

lib/compiler/src/compiler/machine_mapping/machine_mapping.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace FlexFlow {
77

88
MachineMapping combine_disjoint_mappings(MachineMapping const &s1,
99
MachineMapping const &s2) {
10-
return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)};
10+
return MachineMapping{
11+
merge_disjoint_maps(s1.machine_views, s2.machine_views)};
1112
}
1213

1314
bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) {

lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ ParallelLayerGuidObliviousMachineMapping binary_combine_mappings(
1010
ParallelLayerGuidObliviousMachineMapping const &lhs,
1111
ParallelLayerGuidObliviousMachineMapping const &rhs) {
1212
return ParallelLayerGuidObliviousMachineMapping{
13-
merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child),
14-
map_keys(rhs.raw_mapping, nest_inside_right_child)),
13+
merge_disjoint_maps(map_keys(lhs.raw_mapping, nest_inside_left_child),
14+
map_keys(rhs.raw_mapping, nest_inside_right_child)),
1515
};
1616
}
1717

lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ OperatorAttributeConstraint op_type_equals_constraint(OperatorType);
99

1010
OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey,
1111
OperatorAttributeValue const &);
12-
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey,
13-
int denominator);
12+
OperatorAttributeConstraint
13+
op_attr_key_divisible_by(OperatorAttributeKey, nonnegative_int denominator);
1414
OperatorAttributeConstraint
1515
make_equals_constraint(OperatorAttributeExpr const &,
1616
OperatorAttributeValue const &);

lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ includes = [
2121
"op-attrs/tensor_shape.dtg.h",
2222
"op-attrs/datatype.dtg.h",
2323
"<cstddef>",
24+
"utils/nonnegative_int/nonnegative_int.h",
2425
]
2526

2627
src_includes = [
@@ -74,3 +75,6 @@ type = "::FlexFlow::TensorDims"
7475

7576
[[values]]
7677
type = "::FlexFlow::DataType"
78+
79+
[[values]]
80+
type = "::FlexFlow::nonnegative_int"

lib/substitutions/include/substitutions/substitution_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct SubstitutionBuilder {
2626
std::vector<OutputGraphExprValue>
2727
add_output_graph_node(OutputOperatorAttrsAssignment const &node_expr,
2828
std::vector<OutputGraphExprValue> const &inputs,
29-
int num_outputs);
29+
nonnegative_int num_outputs);
3030

3131
PatternNode pattern_node_named(std::string const &) const;
3232
PatternInput pattern_input_named(std::string const &) const;

lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H
33

44
#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h"
5+
#include "utils/nonnegative_int/nonnegative_int.h"
56

67
namespace FlexFlow {
78

89
TensorAttributePattern tensor_attribute_pattern_match_all();
9-
TensorAttributePattern tensor_attr_pattern_require_num_dims(int num_dims);
10+
TensorAttributePattern
11+
tensor_attr_pattern_require_num_dims(nonnegative_int num_dims);
1012

1113
} // namespace FlexFlow
1214

lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ includes = [
1212
"<vector>",
1313
"utils/hash/vector.h",
1414
"utils/fmt/vector.h",
15+
"utils/nonnegative_int/nonnegative_int.h",
1516
]
1617

1718
[[values]]
18-
type = "size_t"
19+
type = "::FlexFlow::nonnegative_int"
1920

2021
[[values]]
21-
type = "std::vector<size_t>"
22+
type = "std::vector<::FlexFlow::nonnegative_int>"

lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ SubParallelComputationGraph
4646
std::unordered_map<parallel_layer_guid_t, ParallelLayerAttrs>
4747
post_node_data_from_sub = output_graph_data.node_data;
4848

49-
return merge_maps(post_node_data_from_orig, post_node_data_from_sub);
49+
return merge_disjoint_maps(post_node_data_from_orig,
50+
post_node_data_from_sub);
5051
}();
5152

5253
std::unordered_set<SubParallelComputationGraphEdge> post_edges = [&] {
@@ -147,7 +148,8 @@ SubParallelComputationGraph
147148

148149
std::unordered_map<open_parallel_tensor_guid_t, ParallelTensorAttrs>
149150
post_value_data_from_sub = output_graph_data.value_data;
150-
return merge_maps(post_value_data_from_orig, post_value_data_from_sub);
151+
return merge_disjoint_maps(post_value_data_from_orig,
152+
post_value_data_from_sub);
151153
}();
152154

153155
SubParallelComputationGraphData post_data = SubParallelComputationGraphData{

lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ OperatorAttributeConstraint
2020
};
2121
}
2222

23-
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key,
24-
int denominator) {
23+
OperatorAttributeConstraint
24+
op_attr_key_divisible_by(OperatorAttributeKey key,
25+
nonnegative_int denominator) {
2526
return OperatorAttributeConstraint{
2627
ConstraintType::DIVISIBLE_BY,
2728
OperatorAttributeExpr{key},

0 commit comments

Comments
 (0)