Skip to content

Commit 7fe5b9c

Browse files
committed
Finish kwarg fixes for ComputationGraph
1 parent 7b7c846 commit 7fe5b9c

File tree

49 files changed

+1014
-322
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1014
-322
lines changed

lib/op-attrs/include/op-attrs/ops/embedding.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "op-attrs/ops/embedding_attrs.dtg.h"
66
#include "op-attrs/parallel_tensor_shape.h"
77
#include "op-attrs/tensor_shape.h"
8+
#include "op-attrs/tensor_slot_name.dtg.h"
89
#include "utils/record_formatter.h"
910
#include <tl/expected.hpp>
1011

@@ -28,7 +29,7 @@ tl::expected<ParallelTensorShape, std::string>
2829
* see
2930
* https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/sparse.py#L180-L182
3031
*/
31-
std::vector<InitializerAttrs> get_initializers(
32+
std::unordered_map<TensorSlotName, InitializerAttrs> get_initializers(
3233
EmbeddingAttrs const &,
3334
std::optional<InitializerAttrs> const &initializer_attrs = std::nullopt);
3435

lib/op-attrs/src/op-attrs/ops/embedding.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ tl::expected<ParallelTensorShape, std::string>
130130
unpar, sum_degree, discard_copy_degree, shard_degrees);
131131
}
132132

133-
std::vector<InitializerAttrs> get_initializers(
133+
std::unordered_map<TensorSlotName, InitializerAttrs> get_initializers(
134134
EmbeddingAttrs const &,
135135
std::optional<InitializerAttrs> const &maybe_initializer_attrs) {
136136
InitializerAttrs default_initializer_attrs = InitializerAttrs{
@@ -141,7 +141,12 @@ std::vector<InitializerAttrs> get_initializers(
141141
},
142142
};
143143

144-
return {maybe_initializer_attrs.value_or(default_initializer_attrs)};
144+
return {
145+
{
146+
TensorSlotName::WEIGHT,
147+
maybe_initializer_attrs.value_or(default_initializer_attrs),
148+
},
149+
};
145150
}
146151

147152
} // namespace FlexFlow

lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ features = [
99
includes = [
1010
"pcg/layer_guid_t.dtg.h",
1111
"pcg/tensor_guid_t.dtg.h",
12-
"utils/fmt/vector.h",
1312
"op-attrs/tensor_slot_name.dtg.h",
1413
]
1514

15+
src_includes = [
16+
"utils/fmt/unordered_map.h",
17+
]
18+
1619
[[fields]]
1720
name = "layer"
1821
type = "::FlexFlow::layer_guid_t"

lib/pcg/include/pcg/computation_graph_builder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,11 @@ struct ComputationGraphBuilder {
255255
TensorShape get_shape(tensor_guid_t const &) const;
256256

257257
private:
258-
std::vector<tensor_guid_t> add_layer(
258+
std::unordered_map<TensorSlotName, tensor_guid_t> add_layer(
259259
LayerAttrs const &layer,
260-
std::vector<tensor_guid_t> const &inputs,
261-
std::vector<InitializerAttrs> const &weights,
262-
std::optional<std::vector<CreateGrad>> const &outputs = std::nullopt);
260+
std::unordered_map<TensorSlotName, tensor_guid_t> const &inputs,
261+
std::unordered_map<TensorSlotName, InitializerAttrs> const &weights,
262+
std::optional<std::unordered_map<TensorSlotName, CreateGrad>> const &outputs = std::nullopt);
263263

264264
tensor_guid_t
265265
broadcast(tensor_guid_t const &, TensorDims const &, std::string const &);

lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ features = [
1010
"fmt",
1111
]
1212

13+
template_params = [
14+
"SlotName",
15+
]
16+
1317
includes = [
1418
"utils/nonnegative_int/nonnegative_int.h",
1519
]
@@ -19,13 +23,13 @@ name = "srcNode"
1923
type = "::FlexFlow::nonnegative_int"
2024

2125
[[fields]]
22-
name = "srcIdx"
23-
type = "::FlexFlow::nonnegative_int"
26+
name = "srcSlot"
27+
type = "SlotName"
2428

2529
[[fields]]
2630
name = "dstNode"
2731
type = "::FlexFlow::nonnegative_int"
2832

2933
[[fields]]
30-
name = "dstIdx"
31-
type = "::FlexFlow::nonnegative_int"
34+
name = "dstSlot"
35+
type = "SlotName"

lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.toml renamed to lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
namespace = "FlexFlow"
2-
name = "V1DataflowGraph"
2+
name = "V1KwargDataflowGraph"
33
type = "struct"
44
features = [
55
"eq",
@@ -10,6 +10,10 @@ features = [
1010
"fmt",
1111
]
1212

13+
template_params = [
14+
"SlotName",
15+
]
16+
1317
includes = [
1418
"<vector>",
1519
"<unordered_set>",
@@ -30,4 +34,4 @@ type = "std::vector<::FlexFlow::nonnegative_int>"
3034

3135
[[fields]]
3236
name = "edges"
33-
type = "std::unordered_set<::FlexFlow::V1GraphEdge>"
37+
type = "std::unordered_set<::FlexFlow::V1GraphEdge<SlotName>>"
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H
2+
#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H
3+
4+
#include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h"
5+
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h"
6+
#include "utils/bidict/algorithms/bidict_from_enumerating.h"
7+
#include "utils/containers/enumerate.h"
8+
#include "utils/containers/sorted.h"
9+
#include "utils/containers/values.h"
10+
#include "utils/graph/node/algorithms.h"
11+
#include "utils/integer_conversions.h"
12+
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h"
13+
14+
namespace FlexFlow {
15+
16+
template <typename SlotName>
17+
V1KwargDataflowGraph<SlotName> to_v1(KwargDataflowGraphView<SlotName> const &g) {
18+
bidict<nonnegative_int, Node> node_enumeration_bidict =
19+
bidict_from_enumerating(get_nodes(g));
20+
std::unordered_map<Node, nonnegative_int> node_enumeration =
21+
node_enumeration_bidict.reversed().as_unordered_map();
22+
return to_v1(g, node_enumeration);
23+
}
24+
25+
template <typename SlotName>
26+
V1KwargDataflowGraph<SlotName> to_v1(KwargDataflowGraphView<SlotName> const &g,
27+
std::unordered_map<Node, nonnegative_int> const &nodes) {
28+
std::unordered_set<V1GraphEdge<SlotName>> edges;
29+
for (KwargDataflowEdge<SlotName> const &e : get_all_kwarg_dataflow_edges(g)) {
30+
edges.insert(V1GraphEdge{
31+
nodes.at(e.src.node), e.src.slot_name, nodes.at(e.dst.node), e.dst.slot_name});
32+
}
33+
34+
return V1KwargDataflowGraph<SlotName>{
35+
sorted(values(nodes)),
36+
edges,
37+
};
38+
}
39+
40+
} // namespace FlexFlow
41+
42+
#endif

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h

Lines changed: 0 additions & 49 deletions
This file was deleted.

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.toml renamed to lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
namespace = "FlexFlow"
2-
name = "V1LabelledDataflowGraph"
2+
name = "V1LabelledKwargDataflowGraph"
33
type = "struct"
44
features = [
55
"eq",
@@ -13,11 +13,12 @@ features = [
1313
template_params = [
1414
"NodeLabel",
1515
"OutputLabel",
16+
"SlotName",
1617
]
1718

1819
includes = [
1920
"<unordered_map>",
20-
"pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h",
21+
"pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h",
2122
"pcg/file_format/v1/graphs/v1_graph_output.dtg.h",
2223
"utils/nonnegative_int/nonnegative_int.h",
2324
]
@@ -35,8 +36,8 @@ type = "std::unordered_map<::FlexFlow::nonnegative_int, NodeLabel>"
3536

3637
[[fields]]
3738
name = "output_labels"
38-
type = "std::unordered_map<::FlexFlow::nonnegative_int, std::vector<OutputLabel>>"
39+
type = "std::unordered_map<::FlexFlow::nonnegative_int, std::unordered_map<SlotName, OutputLabel>>"
3940

4041
[[fields]]
4142
name = "graph"
42-
type = "::FlexFlow::V1DataflowGraph"
43+
type = "::FlexFlow::V1KwargDataflowGraph<SlotName>"

0 commit comments

Comments
 (0)