@@ -86,6 +86,60 @@ namespace refactor::python_ffi {
8686 os.write (ptr, size);
8787 }
8888
89+ static void writeText (std::ofstream os, char const *ptr, size_t size,
90+ DataType dataType, computation::Shape const &shape) {
91+ if (shape.empty ()) {
92+ os << dataType.name () << " <>" << std::endl;
93+ return ;
94+ } else {
95+ auto iter = shape.begin ();
96+ os << dataType.name () << ' <' << *iter++;
97+ while (iter != shape.end ()) { os << ' x' << *iter++; }
98+ os << ' >' << std::endl;
99+ };
100+
101+ #define CASE (T ) \
102+ case DataType::T: { \
103+ using T_ = primitive<DataType::T>::type; \
104+ auto ptr_ = reinterpret_cast <T_ const *>(ptr); \
105+ for (auto i : range0_ (size / sizeof (T_))) { \
106+ os << ptr_[i] << ' \t ' ; \
107+ } \
108+ } break
109+
110+ switch (dataType) {
111+ case DataType::U8: {
112+ auto ptr_ = reinterpret_cast <uint8_t const *>(ptr);
113+ for (auto i : range0_ (size)) {
114+ os << static_cast <int >(ptr_[i]) << ' \t ' ;
115+ }
116+ } break ;
117+ case DataType::I8: {
118+ auto ptr_ = reinterpret_cast <int8_t const *>(ptr);
119+ for (auto i : range0_ (size)) {
120+ os << static_cast <int >(ptr_[i]) << ' \t ' ;
121+ }
122+ } break ;
123+ case DataType::Bool: {
124+ auto ptr_ = reinterpret_cast <bool const *>(ptr);
125+ for (auto i : range0_ (size)) {
126+ os << (ptr_[i] ? " true " : " false" ) << ' \t ' ;
127+ }
128+ } break ;
129+ CASE (F32);
130+ CASE (U16);
131+ CASE (I16);
132+ CASE (I32);
133+ CASE (I64);
134+ CASE (F64);
135+ CASE (U32);
136+ CASE (U64);
137+ default :
138+ UNREACHABLE ();
139+ break ;
140+ }
141+ }
142+
89143 static void writeNpy (std::ofstream os, char const *ptr, size_t size,
90144 DataType dataType, computation::Shape const &shape) {
91145 std::stringstream ss;
@@ -136,7 +190,6 @@ namespace refactor::python_ffi {
136190 fs::create_directories (path);
137191 ASSERT (fs::is_directory (path), " Failed to create \" {}\" " , path.c_str ());
138192
139- auto const npy = format == " npy" ;
140193 size_t dataIdx = 0 ;
141194
142195 auto const &graph = _graph.internal ().contiguous ();
@@ -164,9 +217,12 @@ namespace refactor::python_ffi {
164217 auto file = path / fmt::format (" data{:06}.{}" , dataIdx++, format);
165218 fs::remove (file);
166219 std::ofstream os (file, std::ios::binary);
167- if (npy) {
220+ if (format == " npy" ) {
168221 writeNpy (std::move (os), buffer.data (), size,
169222 edge.tensor ->dataType , edge.tensor ->shape );
223+ } else if (format == " text" ) {
224+ writeText (std::move (os), buffer.data (), size,
225+ edge.tensor ->dataType , edge.tensor ->shape );
170226 } else {
171227 writeBin (std::move (os), buffer.data (), size);
172228 }
0 commit comments