@@ -440,44 +440,77 @@ namespace ctranslate2 {
440440 return os;
441441 }
442442
443- std::ostream& operator <<(std::ostream& os, const StorageView& storage) {
444- StorageView printable (storage.dtype ());
445- printable.copy_from (storage);
446- TYPE_DISPATCH (
447- printable.dtype (),
448- const auto * values = printable.data <T>();
449- if (printable.size () <= PRINT_MAX_VALUES) {
450- for (dim_t i = 0 ; i < printable.size (); ++i) {
451- os << ' ' ;
452- print_value (os, values[i]);
443+ std::ostream& operator <<(std::ostream& os, const StorageView& storage) {
444+ // Create a printable copy of the storage
445+ StorageView printable (storage.dtype ());
446+ printable.copy_from (storage);
447+
448+ // Check the data type and print accordingly
449+ TYPE_DISPATCH (
450+ printable.dtype (),
451+ const auto * values = printable.data <T>();
452+ const auto & shape = printable.shape ();
453+
454+ // Print tensor contents based on dimensionality
455+ if (shape.empty ()) { // Scalar case
456+ os << " Data (Scalar): " << values[0 ] << std::endl;
457+ } else {
458+ os << " Data (" << shape.size () << " D " ;
459+ if (shape.size () == 1 )
460+ os << " Vector" ;
461+ else if (shape.size () == 2 )
462+ os << " Matrix" ;
463+ else
464+ os << " Tensor" ;
465+ os << " ):" << std::endl;
466+ printable.print_tensor (os, values, shape, 0 , 0 , 0 );
467+ os << std::endl;
468+ }
469+ );
470+
471+ // Print metadata
472+ os << " [device:" << device_to_str (storage.device (), storage.device_index ())
473+ << " , dtype:" << dtype_name (storage.dtype ()) << " , storage viewed as " ;
474+ if (storage.is_scalar ())
475+ os << " scalar" ;
476+ else {
477+ for (dim_t i = 0 ; i < storage.rank (); ++i) {
478+ if (i > 0 )
479+ os << ' x' ;
480+ os << storage.dim (i);
453481 }
454482 }
455- else {
456- for (dim_t i = 0 ; i < PRINT_MAX_VALUES / 2 ; ++i) {
457- os << ' ' ;
458- print_value (os, values[i]);
483+ os << ' ]' ;
484+ return os;
485+ }
486+
487+ template <typename T>
488+ void StorageView::print_tensor (std::ostream& os, const T* data, const std::vector<dim_t >& shape, size_t dim, size_t offset, int indent) const {
489+ std::string indentation (indent, ' ' );
490+
491+ os << indentation << " [" ;
492+ bool is_last_dim = (dim == shape.size () - 1 );
493+ for (dim_t i = 0 ; i < shape[dim]; ++i) {
494+ if (i > 0 ) {
495+ os << " , " ;
496+ if (!is_last_dim) {
497+ os << " \n " << std::string (indent, ' ' );
498+ }
459499 }
460- os << " ..." ;
461- for (dim_t i = printable.size () - (PRINT_MAX_VALUES / 2 ); i < printable.size (); ++i) {
462- os << ' ' ;
463- print_value (os, values[i]);
500+
501+ if (i == PRINT_MAX_VALUES / 2 ) {
502+ os << " ..." ;
503+ i = shape[dim] - PRINT_MAX_VALUES / 2 - 1 ; // Skip to the last part
504+ } else {
505+ if (is_last_dim) {
506+ os << data[offset + i];
507+ } else {
508+ print_tensor (os, data, shape, dim + 1 , offset + i * shape[dim + 1 ], indent);
509+ }
464510 }
465511 }
466- os << std::endl);
467- os << " [" << device_to_str (storage.device (), storage.device_index ())
468- << " " << dtype_name (storage.dtype ()) << " storage viewed as " ;
469- if (storage.is_scalar ())
470- os << " scalar" ;
471- else {
472- for (dim_t i = 0 ; i < storage.rank (); ++i) {
473- if (i > 0 )
474- os << ' x' ;
475- os << storage.dim (i);
476- }
512+ os << " ]" ;
477513 }
478- os << ' ]' ;
479- return os;
480- }
481514
482515#define DECLARE_IMPL (T ) \
483516 template \
0 commit comments