Commit 5f4e0a6

Scope handling w/o changes to Tensor
1 parent 9e8bcf7 commit 5f4e0a6

File tree

19 files changed: +238 −99 lines changed


cmake/modules/OpenCL.cmake

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ if(USE_OPENCL)
       message(STATUS "Set OpenCL Target version to " ${CMAKE_MATCH_1})
     endif()
   endif(USE_OPENCL_EXTN_QCOM)
+  if(PROFILE_SHADER_DUMP)
+    add_definitions(-DPROFILE_SHADER_DUMP)
+  endif(PROFILE_SHADER_DUMP)
 else()
   list(APPEND COMPILER_SRCS src/target/opt/build_opencl_off.cc)
 endif(USE_OPENCL)
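
Configuring with -DPROFILE_SHADER_DUMP=ON now defines the PROFILE_SHADER_DUMP macro for every compiled translation unit. The dump site in the OpenCL runtime is not part of this diff; the helper below is only a hypothetical sketch of how such a compile-time guard is typically consumed:

#include <fstream>
#include <string>

// Hypothetical illustration of the guard enabled by -DPROFILE_SHADER_DUMP=ON;
// the real hook inside the OpenCL runtime is not shown in this commit.
inline void MaybeDumpShader(const std::string& name, const std::string& source) {
#ifdef PROFILE_SHADER_DUMP
  std::ofstream out(name + ".cl");  // dump the kernel source for offline inspection
  out << source;
#endif
}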

include/tvm/relax/exec_builder.h

Lines changed: 9 additions & 0 deletions
@@ -122,6 +122,15 @@ class ExecBuilderNode : public Object {
     rv = value;
     return ConvertConstant_(rv);
   }
+  /*!
+   * \brief Update memory scopes.
+   *
+   * This function records the memory scope for a constant.
+   *
+   * \param idx Index of the constant.
+   * \param scope The memory scope.
+   */
+  void SaveMemoryScope(vm::Instruction::Arg idx, ffi::String scope);
   /*!
    * \brief Raw access to underlying executable build in progress.
    */
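
A brief usage sketch, mirroring the codegen_vm.cc change later in this commit: the builder converts the constant first, then records the scope keyed by the resulting constant index (the local variables here stand in for whatever the caller has in scope):

// Sketch: `builder_` is an ExecBuilder, `constant` a runtime tensor, and
// `vdev` the VDevice attached to the constant's struct info, if any.
vm::Instruction::Arg arg = builder_->ConvertConstant(constant);
if (vdev.defined()) {
  builder_->SaveMemoryScope(arg, vdev.value()->memory_scope);
}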

include/tvm/runtime/tensor.h

Lines changed: 6 additions & 26 deletions
@@ -36,7 +36,6 @@
 
 #include <atomic>
 #include <functional>
-#include <string>
 #include <utility>
 #include <vector>
 
@@ -189,25 +188,14 @@ class Tensor : public tvm::ffi::Tensor {
   */
  TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes,
                                    TVMStreamHandle stream = nullptr);
-
-  TVM_DLL void SetScope(ffi::String scope);
-  TVM_DLL ffi::String GetScope() const;
-
- protected:
-  /*!
-   * \brief The memory scope
-   * represents the underlying scope information of device
-   */
-  ffi::String scope = "global";
 };
 
 /*!
  * \brief Save a DLTensor to stream
  * \param strm The output stream
  * \param tensor The tensor to be saved.
- * \param scope The tensor storage scope.
  */
-inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String scope = "global");
+inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
 
 inline void Tensor::CopyFrom(const DLTensor* other) {
   ICHECK(data_ != nullptr);
@@ -232,11 +220,10 @@ inline void Tensor::CopyTo(const Tensor& other) const {
 }
 
 /*! \brief Magic number for Tensor file */
-constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
-constexpr uint64_t kTVMNDArrayScopedMagic = 0xDD5E40F096B4A13E;
+constexpr uint64_t kTVMTensorMagic = 0xDD5E40F096B4A13F;
 
-inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String scope) {
-  uint64_t header = kTVMNDArrayScopedMagic, reserved = 0;
+inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
+  uint64_t header = kTVMTensorMagic, reserved = 0;
   strm->Write(header);
   strm->Write(reserved);
   // Always save data as CPU context
@@ -256,7 +243,6 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String
   strm->Write(tensor->dtype);
   int ndim = tensor->ndim;
   strm->WriteArray(tensor->shape, ndim);
-  strm->Write(std::string(scope));
   int type_bytes = (tensor->dtype.bits + 7) / 8;
   int64_t num_elems = 1;
   for (int i = 0; i < ndim; ++i) {
@@ -280,14 +266,13 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String
   return true;
 }
 
-inline void Tensor::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->(), GetScope()); }
+inline void Tensor::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); }
 
 inline bool Tensor::Load(dmlc::Stream* strm) {
   uint64_t header, reserved;
   ICHECK(strm->Read(&header)) << "Invalid DLTensor file format";
   ICHECK(strm->Read(&reserved)) << "Invalid DLTensor file format";
-  ICHECK((header == kTVMNDArrayMagic) || (header == kTVMNDArrayScopedMagic))
-      << "Invalid DLTensor file format";
+  ICHECK(header == kTVMTensorMagic) << "Invalid DLTensor file format";
   Device dev;
   int ndim;
   DLDataType dtype;
@@ -305,11 +290,6 @@ inline bool Tensor::Load(dmlc::Stream* strm) {
   for (int i = 0; i < ret->ndim; ++i) {
     num_elems *= ret->shape[i];
   }
-  if (header == kTVMNDArrayScopedMagic) {
-    std::string scope;
-    strm->Read(&scope);
-    ret.SetScope(scope);
-  }
   int64_t data_byte_size;
   ICHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
   ICHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format";
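
With the scoped magic removed, a serialized tensor now always starts with kTVMTensorMagic followed by a reserved word, and no storage-scope string follows the shape array. A minimal sketch of validating that prefix before parsing the rest of the stream (assumes kTVMTensorMagic is reachable via the tvm::runtime namespace, as declared above):

#include <dmlc/io.h>
#include <tvm/runtime/tensor.h>

// Sketch: check the magic and reserved words at the front of a serialized
// tensor stream; the scoped magic kTVMNDArrayScopedMagic is no longer accepted.
bool CheckTensorHeader(dmlc::Stream* strm) {
  uint64_t header = 0, reserved = 0;
  if (!strm->Read(&header) || !strm->Read(&reserved)) return false;
  return header == tvm::runtime::kTVMTensorMagic;
}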

include/tvm/runtime/vm/executable.h

Lines changed: 12 additions & 0 deletions
@@ -155,6 +155,8 @@ class VMExecutable : public ffi::ModuleObj {
   std::unordered_map<std::string, Index> func_map;
   /*! \brief The global constant pool. */
   std::vector<ffi::Any> constants;
+  /*! \brief The VDevice memory scopes */
+  std::unordered_map<Index, std::string> memory_scopes;
   /*! \brief The offset of instruction. */
   std::vector<Index> instr_offset;
   /*! \brief The byte data of instruction. */
@@ -177,6 +179,11 @@ class VMExecutable : public ffi::ModuleObj {
    * \param strm The input stream.
    */
   void SaveGlobalSection(dmlc::Stream* strm) const;
+  /*!
+   * \brief Save the memory scopes.
+   * \param strm The output stream.
+   */
+  void SaveMemoryScopeSection(dmlc::Stream* strm) const;
   /*!
    * \brief Save the constant pool.
    * \param strm The input stream.
@@ -197,6 +204,11 @@ class VMExecutable : public ffi::ModuleObj {
    * \param strm The input stream.
    */
   void LoadGlobalSection(dmlc::Stream* strm);
+  /*!
+   * \brief Load the memory scopes.
+   * \param strm The input stream.
+   */
+  void LoadMemoryScopeSection(dmlc::Stream* strm);
   /*!
    * \brief Load the constant pool.
    * \param strm The input stream.
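
The bodies of these two methods live in the corresponding .cc file and are not part of this diff. A plausible sketch of how the memory_scopes map could be serialized with dmlc::Stream, using only the member declared above (everything else is an assumption, not the actual implementation):

// Plausible sketch only; field order and error messages are assumptions.
void VMExecutable::SaveMemoryScopeSection(dmlc::Stream* strm) const {
  uint64_t num_entries = static_cast<uint64_t>(memory_scopes.size());
  strm->Write(num_entries);
  for (const auto& [idx, scope] : memory_scopes) {
    strm->Write(static_cast<int64_t>(idx));  // constant index
    strm->Write(scope);                      // scope string, e.g. "global.texture"
  }
}

void VMExecutable::LoadMemoryScopeSection(dmlc::Stream* strm) {
  uint64_t num_entries = 0;
  ICHECK(strm->Read(&num_entries)) << "Invalid memory scope section";
  for (uint64_t i = 0; i < num_entries; ++i) {
    int64_t idx = 0;
    std::string scope;
    ICHECK(strm->Read(&idx) && strm->Read(&scope)) << "Invalid memory scope section";
    memory_scopes[static_cast<Index>(idx)] = scope;
  }
}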

python/tvm/dlight/adreno/convolution.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 from tvm import tir
 from tvm.target import Target
 
-from .utils import schedule_inline_blocks, schedule_storage_annotate, schedule_default
+from .utils import schedule_inline_blocks, schedule_default
 from .. import analysis
 from .base import AdrenoScheduleRule
 
@@ -102,6 +102,6 @@ def is_convolution(blk):
         Conv2d.schedule_conv2d(sch, conv_blk)
         remaining_blocks = schedule_inline_blocks(sch, remaining_blocks)
         schedule_default(sch, remaining_blocks)
-        schedule_storage_annotate(sch, remaining_blocks)
+        # schedule_storage_annotate(sch, remaining_blocks)
 
         return sch

python/tvm/dlight/adreno/fallback.py

Lines changed: 22 additions & 6 deletions
@@ -26,6 +26,22 @@
 from .utils import get_texture_storage
 
 
+def _assert_gpu_target(target: Target):
+    if "gpu" not in target.keys:
+        raise ValueError(f"Expect a GPU target, but got {target}")
+
+
+def get_max_threads_per_block(target: Target) -> int:
+    _assert_gpu_target(target)
+    max_threads_per_block = None
+    for name in ["max_threads_per_block", "max_num_threads"]:
+        if max_threads_per_block is None:
+            max_threads_per_block = target.attrs.get(name, None)
+    if max_threads_per_block is None:
+        max_threads_per_block = 64
+    return int(max_threads_per_block)
+
+
 # pylint: disable=invalid-name,missing-function-docstring,unused-variable,unused-import
 class Fallback(AdrenoScheduleRule):
     """Texture Based Fallback Schedule(s) for Adreno"""
@@ -46,12 +62,12 @@ def schedule_inline_blocks(
         for blk in blocks:
             block_info = analysis.get_block_info(sch, blk)
             if block_info.is_injective() and not block_info.is_data_pad(sch):
-                if len(block_info.consumers) == 1:
+                if len(sch.get_consumers(blk)) == 1:
                     try:
                         sch.compute_inline(blk)
                     except Exception:  # pylint: disable=broad-exception-caught
                         remaining_blocks.append(blk)
-                elif len(block_info.producers) == 1:
+                elif len(sch.get_producers(blk)) == 1:
                     inlined_once = False
                     try:
                         # Would cause an issue inlining to producer with multiple consumers
@@ -76,15 +92,15 @@ def schedule_default(sch: tir.Schedule, blk: tir.schedule.BlockRV):
         block_info = analysis.get_block_info(sch, blk)
 
         s_loops, r_loops, o_loops = [], [], []
-        v_loop = block_info.write_bufs[0].assoc_lps[-1]
+        v_loop = block_info.write_bufs(sch)[0].assoc_lps[-1]
 
         for iter_info in block_info.iters:
            if sch.get(iter_info.loop_rv) == sch.get(v_loop):
                continue
            {"S": s_loops, "R": r_loops, "O": o_loops}.get(iter_info.kind).append(iter_info.loop_rv)
 
         iter_vars = analysis.collect_block_iter_vars_used_in_access_region(
-            block_info.block_stmt, block_info.write_bufs[0].buf_region.region
+            sch.get(blk), block_info.write_bufs(sch)[0].buf_region.region
         )
         o_outer = [lp for lp in o_loops if sch.get(lp).var in iter_vars]
         o_inner = [lp for lp in o_loops if sch.get(lp).var not in iter_vars]
@@ -100,7 +116,7 @@ def schedule_default(sch: tir.Schedule, blk: tir.schedule.BlockRV):
         tgt = Target.current(allow_none=True)
 
         b = sch.fuse(*s_loops)
-        tx_extent = analysis.get_max_threads_per_block(tgt) if tgt is not None else 256
+        tx_extent = get_max_threads_per_block(tgt) if tgt is not None else 256
         bx, tx = sch.split(b, [None, tx_extent])
         sch.bind(bx, "blockIdx.x")
         sch.bind(tx, "threadIdx.x")
@@ -155,7 +171,7 @@ def apply( # pylint: disable=too-many-locals
             return None
 
         block_infos = [analysis.get_block_info(sch, block) for block in blocks]
-        if not any("texture" in block.write_bufs[0].get_scope() for block in block_infos):
+        if not any("texture" in block.write_bufs(sch)[0].get_scope() for block in block_infos):
            return None
 
         Fallback.schedule_fallback(sch)

python/tvm/dlight/adreno/layout_transform.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def apply( # pylint: disable=too-many-locals
         ):
             return None
 
-        read_buf, write_buf = (block_info.read_bufs[0], block_info.write_bufs[0])
+        read_buf, write_buf = (block_info.read_bufs(sch)[0], block_info.write_bufs(sch)[0])
         lps = block_info.get_loops()
         lpv_read, lpv_write = (
             read_buf.assoc_lps[-1],

python/tvm/dlight/adreno/utils.py

Lines changed: 0 additions & 7 deletions
@@ -83,13 +83,6 @@ def schedule_default(sch, blocks: List[tir.schedule.BlockRV] = None):
     return ret
 
 
-def schedule_storage_annotate(sch: tir.Schedule, func=get_texture_storage):
-    # Check the Write Buffer isn't one of input Params and is Texturizable...
-    from .fallback import Fallback
-
-    return Fallback.schedule_annotate_storage(sch)
-
-
 def schedule_fallback(sch, blk):
     from .fallback import Fallback

python/tvm/dlight/analysis/common_analysis.py

Lines changed: 23 additions & 24 deletions
@@ -64,6 +64,12 @@ def __repr__(self) -> str:
 
 
 get_blockrealize = get_global_func("tir.schedule.GetBlockRealize")
+# BufferIndex Types
+Index = namedtuple("Index", ["sub"])  # c
+RemIndex = namedtuple("RemIndex", ["sub", "div"])  # c%len
+DivIndex = namedtuple("DivIndex", ["sub", "div"])  # c//len
+MergeIndex = namedtuple("MulIndex", ["dom", "mul", "sub"])  # co*len + cb
+BufIndex = List[Union[Index, RemIndex, DivIndex, MergeIndex, None]]
 
 
 # TODO: Shift Vlen Calculation here...
@@ -74,13 +80,6 @@ class BufferInfo:
     assoc_lps: List[Union[tir.schedule.LoopRV, None]]
     assoc_lps_info: List[Union[tir.For, None]]
 
-    # BufferIndex Types
-    Index = namedtuple("Index", ["sub"])  # c
-    RemIndex = namedtuple("RemIndex", ["sub", "div"])  # c%len
-    DivIndex = namedtuple("DivIndex", ["sub", "div"])  # c//len
-    MergeIndex = namedtuple("MulIndex", ["dom", "mul", "sub"])  # co*len + cb
-    BufIndex = List[Union[Index, RemIndex, DivIndex, MergeIndex, None]]
-
     def __init__(
         self,
         sch: tir.Schedule,
@@ -172,8 +171,6 @@ class BlockInfo:
     iters: List[IterInfo]
     block_rv: tir.schedule.BlockRV
     _reduction_block: bool
-    read_bufs: List[BufferInfo]
-    write_bufs: List[BufferInfo]
 
     def __init__(
         self,
@@ -192,6 +189,16 @@ def dom(self) -> List[Union[int, tir.PrimExpr]]:
         """The iteration domain of the block."""
         return [i.dom for i in self.iters]
 
+    def read_bufs(self, sch: tir.Schedule) -> List[BufferInfo]:
+        block_stmt = sch.get(self.block_rv)
+        lps = sch.get_loops(self.block_rv)
+        return [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.reads]
+
+    def write_bufs(self, sch: tir.Schedule) -> List[BufferInfo]:
+        block_stmt = sch.get(self.block_rv)
+        lps = sch.get_loops(self.block_rv)
+        return [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.writes]
+
     def dom_kind(self) -> str:
         """The iteration domain kind of the block, for example, SSSS, SSSR."""
         return "".join(i.kind for i in self.iters)
@@ -216,7 +223,7 @@ def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool:
         if len(r_region) != len(w_region):
             return False
         for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region):
-            if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range(var, w_dom):
+            if not _check_unit_var_range(r_dom, var) or not _check_unit_var_range(w_dom, var):
                 return False
         return True
 
@@ -230,31 +237,23 @@ def is_reduction(self) -> bool:
 
     def is_layout_transform(self, sch: tir.Schedule) -> bool:
         """Whether the Block can be considered having a Layout Transform Pattern"""
-        block_stmt = sch.get(self.block_rv)
-        lps = sch.get_loops(block_rv)
-        read_bufs = [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.reads]
-        write_bufs = [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.writes]
         return (
             all(k == "S" for k in self.dom_kind())
-            and len(write_bufs) == 1
-            and len(read_bufs) == 1
+            and len(self.write_bufs(sch)) == 1
+            and len(self.read_bufs(sch)) == 1
             and not self.is_elementwise(sch)
             and not get_global_func("tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
         )
 
     def is_data_pad(self, sch: tir.Schedule) -> bool:
         """Whether the Block can be considered having a data pad pattern"""
-        block_stmt = sch.get(self.block_rv)
-        lps = sch.get_loops(block_rv)
-        read_bufs = [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.reads]
-        write_bufs = [BufferInfo(sch, self.block_rv, buf, lps) for buf in block_stmt.writes]
         return (
             all(k == "S" for k in self.dom_kind())
-            and len(write_bufs) == 1
-            and len(read_bufs) == 1
+            and len(self.write_bufs(sch)) == 1
+            and len(self.read_bufs(sch)) == 1
             and not self.is_elementwise(sch)
-            and len(self.write_bufs[0].buf_region.region)
-            == len(self.read_bufs[0].buf_region.region)
+            and len(self.write_bufs(sch)[0].buf_region.region)
+            == len(self.read_bufs(sch)[0].buf_region.region)
             and get_global_func("tir.schedule.HasIfThenElse")(sch.get(self.block_rv))
         )

src/relax/backend/vm/codegen_vm.cc

Lines changed: 4 additions & 4 deletions
@@ -215,15 +215,15 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   }
 
   Instruction::Arg VisitExpr_(const ConstantNode* op) final {
+    auto arg = builder_->ConvertConstant(op->data);
+
     if (auto tsinfo = op->struct_info_.as<TensorStructInfoNode>()) {
       if (tsinfo->vdevice.defined()) {
         VDevice vdev = tsinfo->vdevice.value();
-        runtime::Tensor param = op->data;
-        param.SetScope(vdev->memory_scope);
+        builder_->SaveMemoryScope(arg, vdev->memory_scope);
       }
     }
-
-    return builder_->ConvertConstant(op->data);
+    return arg;
   }
 
   Instruction::Arg VisitExpr_(const ShapeExprNode* op) final {
