Skip to content

Commit 9e068ee

Browse files
committed
[*] improve log message for storage view content
1 parent 59c7dda commit 9e068ee

File tree

2 files changed

+54
-19
lines changed

2 files changed

+54
-19
lines changed

include/ctranslate2/storage_view.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ namespace ctranslate2 {
238238

239239
friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);
240240

241+
template <typename T>
242+
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const;
243+
241244
protected:
242245
DataType _dtype = DataType::FLOAT32;
243246
Device _device = Device::CPU;

src/storage_view.cc

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -441,31 +441,35 @@ namespace ctranslate2 {
441441
}
442442

443443
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
444+
// Create a printable copy of the storage
444445
StorageView printable(storage.dtype());
445446
printable.copy_from(storage);
447+
448+
// Check the data type and print accordingly
446449
TYPE_DISPATCH(
447450
printable.dtype(),
448451
const auto* values = printable.data<T>();
449-
if (printable.size() <= PRINT_MAX_VALUES) {
450-
for (dim_t i = 0; i < printable.size(); ++i) {
451-
os << ' ';
452-
print_value(os, values[i]);
453-
}
454-
}
455-
else {
456-
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
457-
os << ' ';
458-
print_value(os, values[i]);
459-
}
460-
os << " ...";
461-
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
462-
os << ' ';
463-
print_value(os, values[i]);
464-
}
452+
const auto& shape = printable.shape();
453+
454+
// Print tensor contents based on dimensionality
455+
if (shape.empty()) { // Scalar case
456+
os << "Data (Scalar): " << values[0] << std::endl;
457+
} else {
458+
os << "Data (" << shape.size() << "D ";
459+
if (shape.size() == 1)
460+
os << "Vector";
461+
else if (shape.size() == 2)
462+
os << "Matrix";
463+
else
464+
os << "Tensor";
465+
os << "):" << std::endl;
466+
printable.print_tensor(os, values, shape, 0, 0, 0);
467+
os << std::endl;
465468
}
466-
os << std::endl);
467-
os << "[" << device_to_str(storage.device(), storage.device_index())
468-
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
469+
);
470+
471+
os << "[device:" << device_to_str(storage.device(), storage.device_index())
472+
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
469473
if (storage.is_scalar())
470474
os << "scalar";
471475
else {
@@ -479,6 +483,34 @@ namespace ctranslate2 {
479483
return os;
480484
}
481485

486+
template <typename T>
487+
void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
488+
std::string indentation(indent, ' ');
489+
490+
os << indentation << "[";
491+
bool is_last_dim = (dim == shape.size() - 1);
492+
for (dim_t i = 0; i < shape[dim]; ++i) {
493+
if (i > 0) {
494+
os << ", ";
495+
if (!is_last_dim) {
496+
os << "\n" << std::string(indent, ' ');
497+
}
498+
}
499+
500+
if (i == PRINT_MAX_VALUES / 2 && shape[dim] > PRINT_MAX_VALUES) {
501+
os << "...";
502+
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
503+
} else {
504+
if (is_last_dim) {
505+
os << +data[offset + i];
506+
} else {
507+
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent);
508+
}
509+
}
510+
}
511+
os << "]";
512+
}
513+
482514
#define DECLARE_IMPL(T) \
483515
template \
484516
StorageView::StorageView(Shape shape, T init, Device device); \

0 commit comments

Comments
 (0)