Skip to content

Commit 7aee760

Browse files
committed
Pass KwargDataflowGraph tests
1 parent f6c9a0e commit 7aee760

File tree

45 files changed

+1038
-19
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1038
-19
lines changed

lib/pcg/src/pcg/computation_graph.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "op-attrs/computation_graph_op_attrs.h"
33
#include "op-attrs/get_incoming_tensor_roles.h"
44
#include "op-attrs/shape_inference.h"
5+
#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h"
56
#include "utils/singular_or_variadic.h"
67
#include "utils/containers/binary_merge_disjoint_maps.h"
78
#include "utils/containers/concat_vectors.h"
@@ -32,7 +33,7 @@ namespace FlexFlow {
3233
ComputationGraph make_empty_computation_graph() {
3334
return ComputationGraph{
3435
LabelledKwargDataflowGraph<LayerAttrs, TensorAttrs, TensorSlotName>::create<
35-
UnorderedSetLabelledOpenDataflowGraph<LayerAttrs, TensorAttrs>>()};
36+
UnorderedSetLabelledOpenKwargDataflowGraph<LayerAttrs, TensorAttrs>>()};
3637
}
3738

3839
std::unordered_set<layer_guid_t> get_layers(ComputationGraph const &cg) {

lib/utils/include/utils/containers/transform.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "utils/containers/vector_transform.h"
55
#include "utils/required_core.h"
66
#include <algorithm>
7+
#include <map>
78
#include <optional>
89
#include <set>
910
#include <type_traits>
@@ -80,6 +81,21 @@ std::unordered_map<K2, V2> transform(std::unordered_map<K, V> const &m,
8081
return result;
8182
}
8283

84+
template <typename K,
85+
typename V,
86+
typename F,
87+
typename K2 = typename std::invoke_result_t<F, K, V>::first_type,
88+
typename V2 = typename std::invoke_result_t<F, K, V>::second_type>
89+
std::unordered_map<K2, V2> transform(std::map<K, V> const &m,
90+
F const &f) {
91+
std::unordered_map<K2, V2> result;
92+
for (auto const &[k, v] : m) {
93+
result.insert(f(k, v));
94+
}
95+
return result;
96+
}
97+
98+
8399
template <typename F, typename T>
84100
std::optional<std::invoke_result_t<F, T>> transform(std::optional<T> const &o,
85101
F &&f) {

lib/utils/include/utils/exception.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace FlexFlow {
1616
":" __LINE__);
1717
#else
1818
#define NOT_IMPLEMENTED() \
19-
throw ::FlexFlow::not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__);
19+
PANIC("Not implemented");
2020
#endif
2121

2222
class not_implemented : public std::logic_error {
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_TASK_SET_OPEN_KWARG_DATAFLOW_GRAPH_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_TASK_SET_OPEN_KWARG_DATAFLOW_GRAPH_H
3+
4+
#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h"
5+
#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h"
6+
#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h"
7+
#include "utils/graph/node/node_source.h"
8+
#include "utils/overload.h"
9+
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h"
10+
#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h"
11+
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h"
12+
#include "utils/containers/contains_key.h"
13+
#include "utils/graph/node/algorithms.h"
14+
#include "utils/containers/generate_map.h"
15+
#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h"
16+
#include "utils/containers/map_values.h"
17+
#include "utils/singular_or_variadic.h"
18+
#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h"
19+
#include "utils/containers/extend.h"
20+
#include "utils/containers/enumerate.h"
21+
22+
namespace FlexFlow {
23+
24+
template <typename NodeLabel,
25+
typename ValueLabel,
26+
typename GraphInputName,
27+
typename SlotName>
28+
struct UnorderedSetLabelledOpenKwargDataflowGraph
29+
: public ILabelledOpenKwargDataflowGraph<NodeLabel, ValueLabel, GraphInputName, SlotName>
30+
, public ILabelledKwargDataflowGraph<NodeLabel, ValueLabel, SlotName>
31+
{
32+
public:
33+
UnorderedSetLabelledOpenKwargDataflowGraph() = default;
34+
35+
KwargNodeAddedResult<SlotName>
36+
add_node(NodeLabel const &node_label,
37+
std::unordered_map<SlotName, SingularOrVariadic<KwargDataflowOutput<SlotName>>> const &inputs,
38+
std::unordered_map<SlotName, SingularOrVariadic<ValueLabel>> const &output_labels) override {
39+
return this->add_node(
40+
node_label,
41+
map_values(inputs,
42+
[](SingularOrVariadic<KwargDataflowOutput<SlotName>> const &input) {
43+
return transform_singular_or_variadic(
44+
input,
45+
[](KwargDataflowOutput<SlotName> const &o) {
46+
return OpenKwargDataflowValue<GraphInputName, SlotName>{o};
47+
});
48+
}),
49+
output_labels);
50+
};
51+
52+
KwargNodeAddedResult<SlotName> add_node(
53+
NodeLabel const &node_label,
54+
std::unordered_map<SlotName, SingularOrVariadic<OpenKwargDataflowValue<GraphInputName, SlotName>>> const &inputs,
55+
std::unordered_map<SlotName, SingularOrVariadic<ValueLabel>> const &output_labels) override
56+
{
57+
Node new_node = this->node_source.new_node();
58+
this->nodes.insert({new_node, node_label});
59+
60+
for (auto const &[input_slot_name, input_val] : inputs) {
61+
KwargDataflowInput<SlotName> dst = KwargDataflowInput<SlotName>{
62+
new_node,
63+
input_slot_name,
64+
};
65+
66+
auto mk_edge_from = [&](OpenKwargDataflowValue<GraphInputName, SlotName> const &src) {
67+
return mk_open_kwarg_dataflow_edge_from_src_val_and_dst(src, dst);
68+
};
69+
70+
std::vector<OpenKwargDataflowEdge<GraphInputName, SlotName>> in_edges = input_val.template visit<
71+
std::vector<OpenKwargDataflowEdge<GraphInputName, SlotName>>
72+
>(overload {
73+
[&](OpenKwargDataflowValue<GraphInputName, SlotName> const &singular_value) {
74+
return std::vector{
75+
mk_edge_from(singular_value),
76+
};
77+
},
78+
[&](std::vector<OpenKwargDataflowValue<GraphInputName, SlotName>> const &variadic_values) {
79+
return transform(variadic_values, mk_edge_from);
80+
}
81+
});
82+
83+
extend(this->edges, in_edges);
84+
}
85+
86+
auto mk_singular_output = [&](SlotName const &slot_name, ValueLabel const &value_label)
87+
-> KwargDataflowOutput<SlotName>
88+
{
89+
KwargDataflowOutput<SlotName> output = KwargDataflowOutput<SlotName>{
90+
/*node=*/new_node,
91+
/*value_ref=*/SlotValueReference{slot_name},
92+
};
93+
94+
this->outputs.insert({
95+
output,
96+
value_label,
97+
});
98+
99+
return output;
100+
};
101+
102+
auto mk_variadic_output = [&](SlotName const &slot_name, std::vector<ValueLabel> const &value_labels)
103+
-> std::vector<KwargDataflowOutput<SlotName>>
104+
{
105+
return transform(vector_of(enumerate(value_labels)),
106+
[&](std::pair<nonnegative_int, ValueLabel> const &entry) -> KwargDataflowOutput<SlotName> {
107+
nonnegative_int entry_idx = entry.first;
108+
ValueLabel entry_value_label = entry.second;
109+
110+
KwargDataflowOutput<SlotName> output = KwargDataflowOutput<SlotName>{
111+
/*node=*/new_node,
112+
/*value_ref=*/SlotValueReference<SlotName>{
113+
VariadicSlotValueReference<SlotName>{
114+
slot_name,
115+
entry_idx,
116+
},
117+
},
118+
};
119+
120+
this->outputs.insert({
121+
output,
122+
entry_value_label,
123+
});
124+
125+
return output;
126+
});
127+
};
128+
129+
auto mk_singular_or_variadic_output = [&](
130+
SlotName const &slot_name, SingularOrVariadic<ValueLabel> const &value_label)
131+
-> SingularOrVariadic<KwargDataflowOutput<SlotName>>
132+
{
133+
return value_label.template visit<
134+
SingularOrVariadic<KwargDataflowOutput<SlotName>>
135+
>(overload {
136+
[&](ValueLabel const &singular_value_label) {
137+
return SingularOrVariadic{
138+
mk_singular_output(slot_name, singular_value_label),
139+
};
140+
},
141+
[&](std::vector<ValueLabel> const &variadic_value_labels) {
142+
return SingularOrVariadic{
143+
mk_variadic_output(slot_name, variadic_value_labels),
144+
};
145+
}
146+
});
147+
};
148+
149+
std::unordered_map<SlotName, SingularOrVariadic<KwargDataflowOutput<SlotName>>> outputs =
150+
generate_map(keys(output_labels),
151+
[&](SlotName const &output_slot)
152+
-> SingularOrVariadic<KwargDataflowOutput<SlotName>>
153+
{
154+
SingularOrVariadic<ValueLabel> value_labels = output_labels.at(output_slot);
155+
156+
return mk_singular_or_variadic_output(output_slot, value_labels);
157+
});
158+
159+
return KwargNodeAddedResult<SlotName>{
160+
/*node=*/new_node,
161+
/*outputs=*/outputs,
162+
};
163+
}
164+
165+
KwargDataflowGraphInput<GraphInputName> add_input(
166+
GraphInputName const &name, ValueLabel const &value_label) override
167+
{
168+
KwargDataflowGraphInput<GraphInputName> input
169+
= KwargDataflowGraphInput{name};
170+
171+
ASSERT(!contains_key(this->graph_inputs, input));
172+
this->graph_inputs.insert({input, value_label});
173+
174+
return input;
175+
}
176+
177+
std::unordered_set<Node> query_nodes(NodeQuery const &q) const override {
178+
return filter(keys(this->nodes),
179+
[&](Node const &n) { return includes(q.nodes, n); });
180+
}
181+
182+
std::unordered_set<OpenKwargDataflowEdge<GraphInputName, SlotName>>
183+
query_edges(OpenKwargDataflowEdgeQuery<GraphInputName, SlotName> const &q) const override {
184+
return filter(this->edges,
185+
[&](OpenKwargDataflowEdge<GraphInputName, SlotName> const &e) {
186+
return open_kwarg_dataflow_edge_query_includes(q, e);
187+
});
188+
}
189+
190+
std::unordered_set<KwargDataflowOutput<SlotName>>
191+
query_outputs(KwargDataflowOutputQuery<SlotName> const &q) const override {
192+
return filter(keys(this->outputs),
193+
[&](KwargDataflowOutput<SlotName> const &output) {
194+
return kwarg_dataflow_output_query_includes(q, output);
195+
});
196+
}
197+
198+
std::unordered_set<KwargDataflowGraphInput<GraphInputName>> get_inputs() const override {
199+
return keys(this->graph_inputs);
200+
}
201+
202+
NodeLabel at(Node const &n) const override {
203+
return this->nodes.at(n);
204+
}
205+
206+
ValueLabel at(OpenKwargDataflowValue<GraphInputName, SlotName> const &v) const override {
207+
return v.template visit<ValueLabel>(overload {
208+
[&](KwargDataflowOutput<SlotName> const &o) -> ValueLabel {
209+
return this->outputs.at(o);
210+
},
211+
[&](KwargDataflowGraphInput<GraphInputName> const &gi) -> ValueLabel {
212+
return this->graph_inputs.at(gi);
213+
}
214+
});
215+
}
216+
217+
void inplace_materialize_from(
218+
LabelledKwargDataflowGraphView<NodeLabel, ValueLabel, SlotName> const &view) override {
219+
std::unordered_set<Node> view_nodes = get_nodes(view);
220+
std::unordered_set<KwargDataflowEdge<SlotName>> view_edges = get_all_kwarg_dataflow_edges(view);
221+
std::unordered_set<KwargDataflowOutput<SlotName>> view_outputs
222+
= get_all_kwarg_dataflow_outputs(view);
223+
224+
this->graph_inputs.clear();
225+
this->nodes = generate_map(view_nodes,
226+
[&](Node const &n) {
227+
return view.at(n);
228+
});
229+
230+
this->edges = transform(view_edges,
231+
[&](KwargDataflowEdge<SlotName> const &e)
232+
-> OpenKwargDataflowEdge<GraphInputName, SlotName>
233+
{
234+
return OpenKwargDataflowEdge<GraphInputName, SlotName>{e};
235+
});
236+
this->outputs = generate_map(view_outputs,
237+
[&](KwargDataflowOutput<SlotName> const &o) {
238+
return view.at(o);
239+
});
240+
}
241+
242+
void inplace_materialize_from(
243+
LabelledOpenKwargDataflowGraphView<NodeLabel, ValueLabel, GraphInputName, SlotName> const &view) override
244+
{
245+
std::unordered_set<KwargDataflowGraphInput<GraphInputName>> view_inputs = get_all_kwarg_dataflow_graph_inputs(view);
246+
std::unordered_set<Node> view_nodes = get_nodes(view);
247+
std::unordered_set<OpenKwargDataflowEdge<GraphInputName, SlotName>> view_edges
248+
= get_all_open_kwarg_dataflow_edges(view);
249+
std::unordered_set<KwargDataflowOutput<SlotName>> view_outputs
250+
= get_all_kwarg_dataflow_outputs(view);
251+
252+
this->graph_inputs = generate_map(view_inputs,
253+
[&](KwargDataflowGraphInput<GraphInputName> const &i) {
254+
return view.at(
255+
OpenKwargDataflowValue<GraphInputName, SlotName>{i}
256+
);
257+
});
258+
this->nodes = generate_map(view_nodes,
259+
[&](Node const &n) {
260+
return view.at(n);
261+
});
262+
263+
this->edges = view_edges;
264+
this->outputs = generate_map(view_outputs,
265+
[&](KwargDataflowOutput<SlotName> const &o) {
266+
return view.at(
267+
OpenKwargDataflowValue<GraphInputName, SlotName>{o}
268+
);
269+
});
270+
}
271+
272+
UnorderedSetLabelledOpenKwargDataflowGraph *clone() const override {
273+
return new UnorderedSetLabelledOpenKwargDataflowGraph{
274+
this->node_source,
275+
this->graph_inputs,
276+
this->nodes,
277+
this->edges,
278+
this->outputs,
279+
};
280+
}
281+
282+
private:
283+
UnorderedSetLabelledOpenKwargDataflowGraph(
284+
NodeSource const &node_source,
285+
std::unordered_map<KwargDataflowGraphInput<GraphInputName>, ValueLabel> const &graph_inputs,
286+
std::unordered_map<Node, NodeLabel> const &nodes,
287+
std::unordered_set<OpenKwargDataflowEdge<GraphInputName, SlotName>> const &edges,
288+
std::unordered_map<KwargDataflowOutput<SlotName>, ValueLabel> const &outputs)
289+
: node_source(node_source), graph_inputs(graph_inputs), nodes(nodes), edges(edges), outputs(outputs)
290+
{ }
291+
292+
293+
private:
294+
NodeSource node_source;
295+
296+
std::unordered_map<KwargDataflowGraphInput<GraphInputName>, ValueLabel> graph_inputs;
297+
std::unordered_map<Node, NodeLabel> nodes;
298+
std::unordered_set<OpenKwargDataflowEdge<GraphInputName, SlotName>> edges;
299+
std::unordered_map<KwargDataflowOutput<SlotName>, ValueLabel> outputs;
300+
};
301+
302+
} // namespace FlexFlow
303+
304+
#endif
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_EDGES_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_EDGES_H
3+
4+
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h"
5+
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h"
6+
7+
namespace FlexFlow {
8+
9+
template <typename SlotName>
10+
std::unordered_set<KwargDataflowEdge<SlotName>> get_all_kwarg_dataflow_edges(
11+
KwargDataflowGraphView<SlotName> const &g)
12+
{
13+
return g.query_edges(kwarg_dataflow_edge_query_all<SlotName>());
14+
}
15+
16+
} // namespace FlexFlow
17+
18+
#endif
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_OUTPUTS_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_OUTPUTS_H
3+
4+
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h"
5+
#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h"
6+
7+
namespace FlexFlow {
8+
9+
template <typename SlotName>
10+
std::unordered_set<KwargDataflowOutput<SlotName>>
11+
get_all_kwarg_dataflow_outputs(
12+
KwargDataflowGraphView<SlotName> const &view) {
13+
return view.query_outputs(kwarg_dataflow_output_query_all<SlotName>());
14+
}
15+
16+
} // namespace FlexFlow
17+
18+
#endif

0 commit comments

Comments
 (0)