Skip to content

Commit 2b82bee

Browse files
committed
task-spec reorganization
1 parent 41c68d7 commit 2b82bee

File tree

185 files changed

+1051
-776
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

185 files changed

+1051
-776
lines changed

.flake/pkgs/ffdb/ffdb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from proj.config_file import get_config_root
1+
from proj import get_repo_root
22
from pathlib import Path
33
import gdb
44

5-
gdb.execute(f'directory {get_config_root(Path.cwd())}')
5+
gdb.execute(f'directory {get_repo_root(Path.cwd())}')
66
gdb.prompt_hook = lambda x: '(ffdb) '
77
gdb.execute('set history save on')
88
gdb.execute('catch throw')

flake.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
flake-utils.url = "github:numtide/flake-utils";
1919

2020
proj-repo = {
21-
url = "git+https://git.sr.ht/~lockshaw/proj?ref=emulated-fs";
21+
# url = "git+https://git.sr.ht/~lockshaw/proj?ref=emulated-fs";
22+
url = "git+file:///home/lockshaw/x/proj/proj?ref=emulated-fs";
2223
inputs.nixpkgs.follows = "nixpkgs";
2324
inputs.flake-utils.follows = "flake-utils";
2425
};

lib/compiler/include/compiler/graph_optimize_state.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ struct GraphOptimizeState {
1717
bool operator<(GraphOptimizeState const &other) const;
1818
};
1919

20-
std::string format_as(GraphOptimizeState const &);
21-
std::ostream &operator<<(std::ostream &, GraphOptimizeState const &);
22-
2320
} // namespace FlexFlow
2421

2522
namespace std {

lib/compiler/include/compiler/machine_mapping/machine_view.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ MachineView make_1d_machine_view(MachineSpaceCoordinate const &start,
5959
MachineSpecificationDimension const &dim,
6060
stride_t stride);
6161

62+
MachineView make_single_device_machine_view(MachineSpaceCoordinate const &);
63+
6264
OperatorAtomicTaskShardBinding
6365
operator_atomic_task_shard_binding_from_machine_view(ComputationGraphOpAttrs const &,
6466
std::vector<ParallelTensorDimDegrees> const &,

lib/compiler/src/compiler/machine_mapping/machine_mapping.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ MappedParallelComputationGraph
1212
ParallelComputationGraph const &pcg,
1313
MachineMapping const &mapping) {
1414

15+
std::unordered_set<parallel_layer_guid_t> pcg_layers = get_parallel_layers(pcg);
16+
std::unordered_set<parallel_layer_guid_t> mapped_layers = keys(mapping.machine_views);
17+
ASSERT(pcg_layers == mapped_layers);
18+
1519
return MappedParallelComputationGraph{
1620
/*pcg=*/pcg,
1721
/*mapped_tasks=*/
@@ -24,6 +28,7 @@ MappedParallelComputationGraph
2428
std::vector<ParallelTensorDimDegrees> inputs_dim_degrees =
2529
get_incoming_input_degrees(pcg, l);
2630

31+
ASSERT(contains_key(mapping.machine_views, l));
2732
MachineView machine_view = mapping.machine_views.at(l);
2833

2934
return mapped_operator_task_group_from_machine_view(

lib/compiler/src/compiler/machine_mapping/machine_view.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ MachineView make_1d_machine_view(MachineSpaceCoordinate const &start,
193193
start, {stride}, {dim});
194194
}
195195

196+
MachineView make_single_device_machine_view(MachineSpaceCoordinate const &coord) {
197+
return machine_view_from_strides_and_machine_spec_dimensions(
198+
coord, {}, {});
199+
}
200+
196201
OperatorAtomicTaskShardBinding
197202
operator_atomic_task_shard_binding_from_machine_view(ComputationGraphOpAttrs const &op_attrs,
198203
std::vector<ParallelTensorDimDegrees> const &inputs_dim_degrees,

lib/compiler/test/src/compiler/graph_optimize_state.cc

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
#include "compiler/graph_optimize_state.h"
22
#include "compiler/machine_mapping/machine_mapping.dtg.h"
3-
#include "compiler/mapped_parallel_computation_graph.h"
3+
#include "compiler/machine_mapping/machine_mapping.h"
4+
#include "compiler/machine_mapping/machine_view.h"
5+
#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h"
46
#include "compiler/machine_mapping/machine_view.dtg.h"
57
#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h"
68
#include <doctest/doctest.h>
9+
#include "test/utils/doctest/check_without_stringify.h"
710

811
using namespace FlexFlow;
912

1013
TEST_SUITE(FF_TEST_SUITE) {
11-
TEST_CASE("GraphOptimizeState::operator==") {
14+
TEST_CASE("GraphOptimizeState operator==") {
1215
TensorShape input_shape = TensorShape{
1316
TensorDims{
1417
FFOrdered{
@@ -19,13 +22,6 @@ TEST_SUITE(FF_TEST_SUITE) {
1922
DataType::FLOAT,
2023
};
2124

22-
// `machine_mapping` is determined by the PCG and the device mapping
23-
// algorithm, and `runtime` is determined by the PCG and the device mapping,
24-
// so their values here do not matter.
25-
MachineMapping empty_machine_mapping = MachineMapping{
26-
std::unordered_map<parallel_layer_guid_t, MachineView>{},
27-
};
28-
2925
InitializerAttrs zero_init = InitializerAttrs{ZeroInitializerAttrs{}};
3026

3127
auto create_pcg = [&]() -> ParallelComputationGraph {
@@ -56,26 +52,45 @@ TEST_SUITE(FF_TEST_SUITE) {
5652
return builder.pcg;
5753
};
5854

55+
auto create_machine_mapping_for_pcg = [](ParallelComputationGraph const &pcg) -> MachineMapping {
56+
MachineSpaceCoordinate device = MachineSpaceCoordinate{
57+
/*node_idx=*/0_n,
58+
/*device_idx=*/0_n,
59+
/*device_type=*/DeviceType::GPU,
60+
};
61+
62+
MachineView machine_view = make_single_device_machine_view(device);
63+
64+
return MachineMapping{
65+
generate_map(get_parallel_layers(pcg),
66+
[&](parallel_layer_guid_t) {
67+
return machine_view;
68+
}),
69+
};
70+
};
71+
5972
ParallelComputationGraph pcg1 = create_pcg();
73+
MachineMapping machine_mapping_1 = create_machine_mapping_for_pcg(pcg1);
6074

6175
SUBCASE("returns true if the PCGs are isomorphic") {
6276
ParallelComputationGraph pcg2 = create_pcg();
77+
MachineMapping machine_mapping_2 = create_machine_mapping_for_pcg(pcg2);
6378

6479
GraphOptimizeState state1 = GraphOptimizeState{
6580
GraphOptimizeResult{
66-
mapped_pcg_from_pcg_and_mapping(pcg1, empty_machine_mapping),
81+
mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1),
6782
},
6883
0,
6984
};
7085

7186
GraphOptimizeState state2 = GraphOptimizeState{
7287
GraphOptimizeResult{
73-
mapped_pcg_from_pcg_and_mapping(pcg2, empty_machine_mapping),
88+
mapped_pcg_from_pcg_and_mapping(pcg2, machine_mapping_2),
7489
},
7590
0,
7691
};
7792

78-
CHECK(state1 == state2);
93+
CHECK_WITHOUT_STRINGIFY(state1 == state2);
7994
}
8095

8196
SUBCASE("returns false if the PCGs are not isomorphic") {
@@ -93,23 +108,25 @@ TEST_SUITE(FF_TEST_SUITE) {
93108
/*bias_initializer=*/zero_init,
94109
/*name=*/"dense0");
95110

96-
ParallelComputationGraph pcg_ = builder_.pcg;
111+
ParallelComputationGraph other_pcg = builder_.pcg;
112+
113+
MachineMapping other_machine_mapping = create_machine_mapping_for_pcg(other_pcg);
97114

98115
GraphOptimizeState state1 = GraphOptimizeState{
99116
GraphOptimizeResult{
100-
mapped_pcg_from_pcg_and_mapping(pcg1, empty_machine_mapping),
117+
mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1),
101118
},
102119
0,
103120
};
104121

105122
GraphOptimizeState state_ = GraphOptimizeState{
106123
GraphOptimizeResult{
107-
mapped_pcg_from_pcg_and_mapping(pcg_, empty_machine_mapping),
124+
mapped_pcg_from_pcg_and_mapping(other_pcg, other_machine_mapping),
108125
},
109126
0,
110127
};
111128

112-
CHECK_FALSE(state1 == state_);
129+
CHECK_FALSE_WITHOUT_STRINGIFY(state1 == state_);
113130
}
114131
}
115132
}

lib/compiler/test/src/compiler/task_graph_simulator/task_simulator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
#include "pcg/machine_specification_dimension.dtg.h"
1616
#include "compiler/machine_mapping/machine_view.dtg.h"
1717
#include "compiler/machine_mapping/machine_view.h"
18-
#include "pcg/machine_view_dimension.dtg.h"
18+
#include "compiler/machine_mapping/machine_view_dimension.dtg.h"
1919
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
2020
#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h"
2121
#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h"
2222
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"
23-
#include "pcg/stride_t.dtg.h"
23+
#include "compiler/machine_mapping/stride_t.dtg.h"
2424
#include "substitutions/sub_parallel_computation_graph.dtg.h"
2525
#include "substitutions/sub_parallel_computation_graph.h"
2626
#include "utils/containers/get_only.h"

lib/kernels/include/kernels/accessor.h

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ class GenericTensorAccessorR {
6767
decltype(device_type) const &>
6868
tie() const;
6969

70-
// TODO(@lockshaw)(#pr): delete
71-
// friend ::std::hash<GenericTensorAccessorR>;
70+
friend ::std::hash<GenericTensorAccessorR>;
7271
};
7372

7473
std::string format_as(GenericTensorAccessorR const &);
@@ -137,8 +136,7 @@ class GenericTensorAccessorW {
137136
decltype(device_type) const &>
138137
tie() const;
139138

140-
// TODO(@lockshaw)(#pr): delete
141-
// friend ::std::hash<GenericTensorAccessorW>;
139+
friend ::std::hash<GenericTensorAccessorW>;
142140
};
143141

144142
std::string format_as(GenericTensorAccessorW const &);
@@ -228,19 +226,18 @@ real_type_t<DT> accessor_get_only_value(GenericTensorAccessorR const &acc) {
228226

229227
} // namespace FlexFlow
230228

231-
// TODO(@lockshaw)(#pr): delete
232-
// namespace std {
233-
//
234-
// template <>
235-
// struct hash<::FlexFlow::GenericTensorAccessorR> {
236-
// size_t operator()(::FlexFlow::GenericTensorAccessorR const &) const;
237-
// };
238-
//
239-
// template <>
240-
// struct hash<::FlexFlow::GenericTensorAccessorW> {
241-
// size_t operator()(::FlexFlow::GenericTensorAccessorW const &) const;
242-
// };
243-
//
244-
// }
245-
//
229+
namespace std {
230+
231+
template <>
232+
struct hash<::FlexFlow::GenericTensorAccessorR> {
233+
size_t operator()(::FlexFlow::GenericTensorAccessorR const &) const;
234+
};
235+
236+
template <>
237+
struct hash<::FlexFlow::GenericTensorAccessorW> {
238+
size_t operator()(::FlexFlow::GenericTensorAccessorW const &) const;
239+
};
240+
241+
}
242+
246243
#endif

0 commit comments

Comments
 (0)