Skip to content

Commit 5e774c6

Browse files
committed
Fri Dec 27 03:20:32 PM PST 2024
1 parent e635b1f commit 5e774c6

33 files changed

+593
-61
lines changed

lib/kernels/include/kernels/legion_dim.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
namespace FlexFlow {
88

9+
std::set<legion_dim_t> legion_dim_range(int end);
10+
911
legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value);
1012

1113
legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions);
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
#include "kernels/legion_dim.h"
2+
#include "utils/containers/range.h"
3+
#include "utils/containers/transform.h"
4+
#include "utils/containers/set_of.h"
25

36
namespace FlexFlow {
47

8+
std::set<legion_dim_t> legion_dim_range(int end) {
9+
return set_of(transform(range(end), [](int i) { return ff_dim_t{i}; }));
10+
}
11+
512
legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) {
613
return legion_dim_t(legion_dim.value + value);
714
}

lib/op-attrs/include/op-attrs/ff_dim.h

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H
2+
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H
3+
4+
#include "op-attrs/ff_dim.dtg.h"
5+
#include "rapidcheck.h"
6+
7+
namespace FlexFlow {
8+
9+
std::set<ff_dim_t> ff_dim_range(int end);
10+
11+
} // namespace FlexFlow
12+
13+
namespace rc {
14+
template <>
15+
struct Arbitrary<FlexFlow::ff_dim_t> {
16+
static Gen<FlexFlow::ff_dim_t> arbitrary();
17+
};
18+
} // namespace rc
19+
20+
#endif
File renamed without changes.

lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H
33

44
#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h"
5-
#include "op-attrs/parallel_tensor_dim_degrees.dtg.h"
5+
#include "op-attrs/tensor_num_dims.dtg.h"
66

77
namespace FlexFlow {
88

99
OperatorSpaceParallelTensorSpaceMapping
10-
get_identity_mapping(ParallelTensorDimDegrees const &);
10+
get_identity_mapping(TensorNumDims const &);
1111

1212
} // namespace FlexFlow
1313

lib/op-attrs/include/op-attrs/ops/linear.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
#include "op-attrs/ops/linear_attrs.dtg.h"
77
#include "op-attrs/parallel_tensor_dim_degrees.dtg.h"
88
#include "op-attrs/parallel_tensor_shape.dtg.h"
9+
#include "op-attrs/tensor_num_dims.dtg.h"
910
#include "op-attrs/tensor_shape.dtg.h"
1011
#include "utils/record_formatter.h"
1112
#include <tl/expected.hpp>
1213
#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h"
14+
#include "op-attrs/parallel_tensor_space_mapping.dtg.h"
1315

1416
namespace FlexFlow {
1517

@@ -27,6 +29,10 @@ tl::expected<TensorShape, std::string> get_bias_shape(LinearAttrs const &attrs,
2729
tl::expected<TensorShape, std::string>
2830
get_output_shape(LinearAttrs const &attrs, TensorShape const &input);
2931

32+
tl::expected<ParallelTensorSpaceMapping, std::string>
33+
get_projection_to_output_parallel_dim_mapping(LinearAttrs const &attrs,
34+
ParallelTensorDimDegrees const &input);
35+
3036
tl::expected<ParallelTensorDimDegrees, std::string>
3137
get_projection_parallel_dim_degrees(LinearAttrs const &attrs, ParallelTensorDimDegrees const &input);
3238
tl::expected<ParallelTensorDimDegrees, std::string>
@@ -51,7 +57,7 @@ tl::expected<OperatorSpaceParallelTensorSpaceMapping, std::string>
5157
ParallelTensorDimDegrees const &input);
5258
tl::expected<OperatorSpaceParallelTensorSpaceMapping, std::string>
5359
get_output_space_mapping(LinearAttrs const &attrs,
54-
ParallelTensorDimDegrees const &input);
60+
TensorNumDims const &input_num_dims);
5561

5662
} // namespace FlexFlow
5763

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H
2+
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H
3+
4+
#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h"
5+
6+
namespace FlexFlow {
7+
8+
parallel_tensor_dim_idx_t sum_dim_idx();
9+
parallel_tensor_dim_idx_t discard_copy_dim_idx();
10+
parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t);
11+
12+
} // namespace FlexFlow
13+
14+
#endif
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H
2+
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H
3+
4+
namespace FlexFlow {
5+
6+
} // namespace FlexFlow
7+
8+
#endif
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
namespace = "FlexFlow"
2+
name = "ParallelTensorSpaceMapping"
3+
features = [
4+
"eq",
5+
"hash",
6+
"fmt",
7+
]
8+
9+
includes = [
10+
"utils/orthotope/dim_projection.dtg.h",
11+
"op-attrs/parallel_tensor_dim_idx_t.dtg.h",
12+
]
13+
14+
[[fields]]
15+
name = "raw_projection"
16+
type = "::FlexFlow::DimProjection<::FlexFlow::parallel_tensor_dim_idx_t, ::FlexFlow::parallel_tensor_dim_idx_t>"

0 commit comments

Comments
 (0)