Skip to content

Commit ec39fd7

Browse files
committed
feat(python_ffi): 添加一个字符型的 trace 格式
Signed-off-by: YdrMaster <[email protected]>
1 parent a163eee commit ec39fd7

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

src/09python_ffi/src/executor.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,60 @@ namespace refactor::python_ffi {
8686
os.write(ptr, size);
8787
}
8888

89+
static void writeText(std::ofstream os, char const *ptr, size_t size,
90+
DataType dataType, computation::Shape const &shape) {
91+
if (shape.empty()) {
92+
os << dataType.name() << "<>" << std::endl;
93+
return;
94+
} else {
95+
auto iter = shape.begin();
96+
os << dataType.name() << '<' << *iter++;
97+
while (iter != shape.end()) { os << 'x' << *iter++; }
98+
os << '>' << std::endl;
99+
};
100+
101+
#define CASE(T) \
102+
case DataType::T: { \
103+
using T_ = primitive<DataType::T>::type; \
104+
auto ptr_ = reinterpret_cast<T_ const *>(ptr); \
105+
for (auto i : range0_(size / sizeof(T_))) { \
106+
os << ptr_[i] << '\t'; \
107+
} \
108+
} break
109+
110+
switch (dataType) {
111+
case DataType::U8: {
112+
auto ptr_ = reinterpret_cast<uint8_t const *>(ptr);
113+
for (auto i : range0_(size)) {
114+
os << static_cast<int>(ptr_[i]) << '\t';
115+
}
116+
} break;
117+
case DataType::I8: {
118+
auto ptr_ = reinterpret_cast<int8_t const *>(ptr);
119+
for (auto i : range0_(size)) {
120+
os << static_cast<int>(ptr_[i]) << '\t';
121+
}
122+
} break;
123+
case DataType::Bool: {
124+
auto ptr_ = reinterpret_cast<bool const *>(ptr);
125+
for (auto i : range0_(size)) {
126+
os << (ptr_[i] ? "true " : "false") << '\t';
127+
}
128+
} break;
129+
CASE(F32);
130+
CASE(U16);
131+
CASE(I16);
132+
CASE(I32);
133+
CASE(I64);
134+
CASE(F64);
135+
CASE(U32);
136+
CASE(U64);
137+
default:
138+
UNREACHABLE();
139+
break;
140+
}
141+
}
142+
89143
static void writeNpy(std::ofstream os, char const *ptr, size_t size,
90144
DataType dataType, computation::Shape const &shape) {
91145
std::stringstream ss;
@@ -136,7 +190,6 @@ namespace refactor::python_ffi {
136190
fs::create_directories(path);
137191
ASSERT(fs::is_directory(path), "Failed to create \"{}\"", path.c_str());
138192

139-
auto const npy = format == "npy";
140193
size_t dataIdx = 0;
141194

142195
auto const &graph = _graph.internal().contiguous();
@@ -164,9 +217,12 @@ namespace refactor::python_ffi {
164217
auto file = path / fmt::format("data{:06}.{}", dataIdx++, format);
165218
fs::remove(file);
166219
std::ofstream os(file, std::ios::binary);
167-
if (npy) {
220+
if (format == "npy") {
168221
writeNpy(std::move(os), buffer.data(), size,
169222
edge.tensor->dataType, edge.tensor->shape);
223+
} else if (format == "text") {
224+
writeText(std::move(os), buffer.data(), size,
225+
edge.tensor->dataType, edge.tensor->shape);
170226
} else {
171227
writeBin(std::move(os), buffer.data(), size);
172228
}

0 commit comments

Comments
 (0)