Skip to content

Commit 7b7c846

Browse files
committed
Remove variadic args, more kwarg graph algorithms
1 parent 8c6f774 commit 7b7c846

File tree

105 files changed

+1510
-1194
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+1510
-1194
lines changed

lib/op-attrs/include/op-attrs/ops/attention.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h"
88
#include "op-attrs/ops/attention_attrs.dtg.h"
99
#include "op-attrs/parallel_tensor_shape.dtg.h"
10-
#include "utils/singular_or_variadic.dtg.h"
1110
#include "op-attrs/tensor_shape.dtg.h"
1211
#include "op-attrs/tensor_slot_name.dtg.h"
1312
#include <tl/expected.hpp>
@@ -64,7 +63,7 @@ tl::expected<TensorShape, std::string>
6463
TensorShape const &input_k,
6564
TensorShape const &input_v);
6665

67-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
66+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
6867
get_weight_shapes(MultiHeadAttentionAttrs const &,
6968
TensorShape const &input_q,
7069
TensorShape const &input_k,
@@ -107,7 +106,7 @@ tl::expected<ParallelTensorShape, std::string>
107106
ParallelTensorShape const &input_k,
108107
ParallelTensorShape const &input_v);
109108

110-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
109+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
111110
get_weight_shapes(MultiHeadAttentionAttrs const &,
112111
ParallelTensorShape const &input_q,
113112
ParallelTensorShape const &input_k,

lib/op-attrs/include/op-attrs/ops/batch_norm.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "op-attrs/ops/batch_norm_attrs.dtg.h"
77
#include "op-attrs/parallel_tensor_dim_degrees.dtg.h"
88
#include "op-attrs/parallel_tensor_shape.dtg.h"
9-
#include "utils/singular_or_variadic.dtg.h"
109
#include "op-attrs/tensor_shape.dtg.h"
1110
#include "op-attrs/tensor_slot_name.dtg.h"
1211
#include <tl/expected.hpp>
@@ -23,7 +22,7 @@ tl::expected<TensorShape, std::string>
2322
tl::expected<TensorShape, std::string>
2423
get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &);
2524

26-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
25+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
2726
get_weight_shapes(BatchNormAttrs const &attrs,
2827
TensorShape const &input_shape);
2928

@@ -50,7 +49,7 @@ tl::expected<ParallelTensorShape, std::string>
5049
tl::expected<ParallelTensorShape, std::string>
5150
get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &);
5251

53-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
52+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
5453
get_weight_shapes(BatchNormAttrs const &attrs,
5554
ParallelTensorShape const &input_shape);
5655

lib/op-attrs/include/op-attrs/ops/conv_2d.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "op-attrs/initializer_attrs.dtg.h"
66
#include "op-attrs/ops/conv_2d_attrs.dtg.h"
77
#include "op-attrs/parallel_tensor_shape.h"
8-
#include "utils/singular_or_variadic.dtg.h"
98
#include "op-attrs/tensor_shape.h"
109
#include "op-attrs/tensor_slot_name.dtg.h"
1110

@@ -20,7 +19,7 @@ TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input);
2019
TensorShape get_output_shape(Conv2DAttrs const &attrs,
2120
TensorShape const &input);
2221

23-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>
22+
std::unordered_map<TensorSlotName, TensorShape>
2423
get_weight_shapes(Conv2DAttrs const &attrs,
2524
TensorShape const &input_shape);
2625

@@ -31,7 +30,7 @@ ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs,
3130
ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs,
3231
ParallelTensorShape const &input_shape);
3332

34-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>
33+
std::unordered_map<TensorSlotName, ParallelTensorShape>
3534
get_weight_shapes(Conv2DAttrs const &attrs,
3635
ParallelTensorShape const &input_shape);
3736

lib/op-attrs/include/op-attrs/ops/layer_norm.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "op-attrs/initializer_attrs.dtg.h"
66
#include "op-attrs/ops/layer_norm_attrs.dtg.h"
77
#include "op-attrs/parallel_tensor_shape.dtg.h"
8-
#include "utils/singular_or_variadic.dtg.h"
98
#include "op-attrs/tensor_shape.dtg.h"
109
#include "op-attrs/tensor_slot_name.dtg.h"
1110
#include <tl/expected.hpp>
@@ -22,7 +21,7 @@ tl::expected<TensorShape, std::string>
2221
tl::expected<TensorShape, std::string>
2322
get_beta_weights_shape(LayerNormAttrs const &, TensorShape const &);
2423

25-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
24+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
2625
get_weight_shapes(LayerNormAttrs const &attrs,
2726
TensorShape const &input_shape);
2827

@@ -34,7 +33,7 @@ tl::expected<ParallelTensorShape, std::string>
3433
tl::expected<ParallelTensorShape, std::string>
3534
get_beta_weights_shape(LayerNormAttrs const &, ParallelTensorShape const &);
3635

37-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
36+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
3837
get_weight_shapes(LayerNormAttrs const &attrs,
3938
ParallelTensorShape const &input_shape);
4039

lib/op-attrs/include/op-attrs/ops/linear.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "op-attrs/parallel_tensor_dim_degrees.dtg.h"
1111
#include "op-attrs/parallel_tensor_shape.dtg.h"
1212
#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.h"
13-
#include "utils/singular_or_variadic.dtg.h"
1413
#include "op-attrs/tensor_shape.dtg.h"
1514
#include "op-attrs/tensor_slot_name.dtg.h"
1615
#include "utils/record_formatter.h"
@@ -30,7 +29,7 @@ tl::expected<TensorShape, std::string> get_bias_shape(LinearAttrs const &attrs,
3029
tl::expected<TensorShape, std::string>
3130
get_output_shape(LinearAttrs const &attrs, TensorShape const &input);
3231

33-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
32+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
3433
get_weight_shapes(LinearAttrs const &attrs, TensorShape const &input_shape);
3534

3635
ParallelTensorDimDegrees
@@ -52,7 +51,7 @@ tl::expected<ParallelTensorShape, std::string>
5251
get_output_shape(LinearAttrs const &attrs,
5352
ParallelTensorShape const &input);
5453

55-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
54+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
5655
get_weight_shapes(LinearAttrs const &attrs,
5756
ParallelTensorShape const &input_shape);
5857

lib/op-attrs/include/op-attrs/shape_inference.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,26 @@
44
#include "op-attrs/computation_graph_op_attrs.dtg.h"
55
#include "op-attrs/parallel_tensor_shape.dtg.h"
66
#include "op-attrs/pcg_operator_attrs.dtg.h"
7-
#include "utils/singular_or_variadic.dtg.h"
87
#include "op-attrs/tensor_slot_name.dtg.h"
98
#include <vector>
109

1110
namespace FlexFlow {
1211

13-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>
12+
std::unordered_map<TensorSlotName, TensorShape>
1413
get_output_shapes(ComputationGraphOpAttrs const &,
15-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>> const &input_shapes);
14+
std::unordered_map<TensorSlotName, TensorShape> const &input_shapes);
1615

17-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>
16+
std::unordered_map<TensorSlotName, TensorShape>
1817
get_weight_shapes(ComputationGraphOpAttrs const &,
19-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>> const &input_shapes);
18+
std::unordered_map<TensorSlotName, TensorShape> const &input_shapes);
2019

21-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>
20+
std::unordered_map<TensorSlotName, ParallelTensorShape>
2221
get_output_shapes(PCGOperatorAttrs const &,
23-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>> const &input_shapes);
22+
std::unordered_map<TensorSlotName, ParallelTensorShape> const &input_shapes);
2423

25-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>
24+
std::unordered_map<TensorSlotName, ParallelTensorShape>
2625
get_weight_shapes(PCGOperatorAttrs const &,
27-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>> const &input_shapes);
26+
std::unordered_map<TensorSlotName, ParallelTensorShape> const &input_shapes);
2827

2928
} // namespace FlexFlow
3029

lib/op-attrs/include/op-attrs/tensor_slot_name.dtg.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,27 @@ name = "VALUE"
5555

5656
[[values]]
5757
name = "LOGIT"
58+
59+
[[values]]
60+
name = "INPUT_0"
61+
62+
[[values]]
63+
name = "INPUT_1"
64+
65+
[[values]]
66+
name = "INPUT_2"
67+
68+
[[values]]
69+
name = "INPUT_3"
70+
71+
[[values]]
72+
name = "OUTPUT_0"
73+
74+
[[values]]
75+
name = "OUTPUT_1"
76+
77+
[[values]]
78+
name = "OUTPUT_2"
79+
80+
[[values]]
81+
name = "OUTPUT_3"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SLOT_NAME_H
2+
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SLOT_NAME_H
3+
4+
#include "op-attrs/tensor_slot_name.dtg.h"
5+
6+
namespace FlexFlow {
7+
8+
std::vector<TensorSlotName> get_variadic_inputs_slot_name_sequence();
9+
std::vector<TensorSlotName> get_variadic_outputs_slot_name_sequence();
10+
11+
} // namespace FlexFlow
12+
13+
#endif

lib/op-attrs/src/op-attrs/ops/attention.cc

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -233,34 +233,28 @@ tl::expected<TensorShape, std::string>
233233
};
234234
}
235235

236-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
236+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
237237
get_weight_shapes(MultiHeadAttentionAttrs const &attrs,
238238
TensorShape const &input_q,
239239
TensorShape const &input_k,
240240
TensorShape const &input_v) {
241241

242-
std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>> weight_shapes = {
242+
std::unordered_map<TensorSlotName, TensorShape> weight_shapes = {
243243
{
244244
TensorSlotName::WEIGHT,
245-
SingularOrVariadic<TensorShape>{
246-
PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)),
247-
},
245+
PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)),
248246
},
249247
};
250248

251249
if (attrs.bias) {
252250
weight_shapes.insert({
253251
TensorSlotName::INPUT_BIAS,
254-
SingularOrVariadic<TensorShape>{
255-
PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)),
256-
},
252+
PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)),
257253
});
258254

259255
weight_shapes.insert({
260256
TensorSlotName::OUTPUT_BIAS,
261-
SingularOrVariadic<TensorShape>{
262-
PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)),
263-
},
257+
PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)),
264258
});
265259
}
266260

@@ -422,34 +416,28 @@ positive_int get_oSize(TensorShape const &) {
422416
NOT_IMPLEMENTED();
423417
}
424418

425-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
419+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
426420
get_weight_shapes(MultiHeadAttentionAttrs const &attrs,
427421
ParallelTensorShape const &input_q,
428422
ParallelTensorShape const &input_k,
429423
ParallelTensorShape const &input_v) {
430424

431-
std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>> weight_shapes = {
425+
std::unordered_map<TensorSlotName, ParallelTensorShape> weight_shapes = {
432426
{
433427
TensorSlotName::WEIGHT,
434-
SingularOrVariadic<ParallelTensorShape>{
435-
PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)),
436-
},
428+
PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)),
437429
},
438430
};
439431

440432
if (attrs.bias) {
441433
weight_shapes.insert({
442434
TensorSlotName::INPUT_BIAS,
443-
SingularOrVariadic<ParallelTensorShape>{
444-
PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)),
445-
},
435+
PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)),
446436
});
447437

448438
weight_shapes.insert({
449439
TensorSlotName::OUTPUT_BIAS,
450-
SingularOrVariadic<ParallelTensorShape>{
451-
PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)),
452-
},
440+
PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)),
453441
});
454442
}
455443

lib/op-attrs/src/op-attrs/ops/batch_norm.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ tl::expected<TensorShape, std::string>
9696
return get_gamma_weights_shape(attrs, input_shape);
9797
}
9898

99-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>, std::string>
99+
tl::expected<std::unordered_map<TensorSlotName, TensorShape>, std::string>
100100
get_weight_shapes(BatchNormAttrs const &attrs,
101101
TensorShape const &input_shape) {
102102

@@ -105,18 +105,14 @@ tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>
105105
TensorShape beta_shape =
106106
PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape));
107107

108-
return std::unordered_map<TensorSlotName, SingularOrVariadic<TensorShape>>{
108+
return std::unordered_map<TensorSlotName, TensorShape>{
109109
{
110110
TensorSlotName::GAMMA,
111-
SingularOrVariadic<TensorShape>{
112-
gamma_shape,
113-
},
111+
gamma_shape,
114112
},
115113
{
116114
TensorSlotName::BETA,
117-
SingularOrVariadic<TensorShape>{
118-
beta_shape,
119-
},
115+
beta_shape,
120116
},
121117
};
122118
}
@@ -318,7 +314,7 @@ tl::expected<ParallelTensorShape, std::string>
318314
return lift_to_parallel_with_degrees(unpar, degrees);
319315
}
320316

321-
tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>, std::string>
317+
tl::expected<std::unordered_map<TensorSlotName, ParallelTensorShape>, std::string>
322318
get_weight_shapes(BatchNormAttrs const &attrs,
323319
ParallelTensorShape const &input_shape) {
324320

@@ -327,18 +323,14 @@ tl::expected<std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTenso
327323
ParallelTensorShape beta_shape =
328324
PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape));
329325

330-
return std::unordered_map<TensorSlotName, SingularOrVariadic<ParallelTensorShape>>{
326+
return std::unordered_map<TensorSlotName, ParallelTensorShape>{
331327
{
332328
TensorSlotName::GAMMA,
333-
SingularOrVariadic<ParallelTensorShape>{
334-
gamma_shape,
335-
},
329+
gamma_shape,
336330
},
337331
{
338332
TensorSlotName::BETA,
339-
SingularOrVariadic<ParallelTensorShape>{
340-
beta_shape,
341-
},
333+
beta_shape,
342334
},
343335
};
344336
}

0 commit comments

Comments
 (0)