@@ -441,31 +441,35 @@ namespace ctranslate2 {
441441 }
442442
443443 std::ostream& operator <<(std::ostream& os, const StorageView& storage) {
444+ // Create a printable copy of the storage
444445 StorageView printable (storage.dtype ());
445446 printable.copy_from (storage);
447+
448+ // Check the data type and print accordingly
446449 TYPE_DISPATCH (
447450 printable.dtype (),
448451 const auto * values = printable.data <T>();
449- if ( printable.size () <= PRINT_MAX_VALUES) {
450- for ( dim_t i = 0 ; i < printable. size (); ++i) {
451- os << ' ' ;
452- print_value (os, values[i]);
453- }
454- }
455- else {
456- for ( dim_t i = 0 ; i < PRINT_MAX_VALUES / 2 ; ++i) {
457- os << ' ' ;
458- print_value (os, values[i]);
459- }
460- os << " ... " ;
461- for ( dim_t i = printable. size () - (PRINT_MAX_VALUES / 2 ); i < printable. size (); ++i) {
462- os << ' ' ;
463- print_value (os, values[i] );
464- }
452+ const auto & shape = printable.shape ();
453+
454+ // Print tensor contents based on dimensionality
455+ if (shape. empty ()) { // Scalar case
456+ os << " Data (Scalar): " << values[ 0 ] << std::endl;
457+ } else {
458+ os << " Data ( " << shape. size () << " D " ;
459+ if (shape. size () == 1 )
460+ os << " Vector " ;
461+ else if (shape. size () == 2 )
462+ os << " Matrix " ;
463+ else
464+ os << " Tensor " ;
465+ os << " ): " << std::endl ;
466+ printable. print_tensor (os, values, shape, 0 , 0 , 0 );
467+ os << std::endl;
465468 }
466- os << std::endl);
467- os << " [" << device_to_str (storage.device (), storage.device_index ())
468- << " " << dtype_name (storage.dtype ()) << " storage viewed as " ;
469+ );
470+
471+ os << " [device:" << device_to_str (storage.device (), storage.device_index ())
472+ << " , dtype:" << dtype_name (storage.dtype ()) << " , storage viewed as " ;
469473 if (storage.is_scalar ())
470474 os << " scalar" ;
471475 else {
@@ -479,6 +483,34 @@ namespace ctranslate2 {
479483 return os;
480484 }
481485
486+ template <typename T>
487+ void StorageView::print_tensor (std::ostream& os, const T* data, const std::vector<dim_t >& shape, size_t dim, size_t offset, int indent) const {
488+ std::string indentation (indent, ' ' );
489+
490+ os << indentation << " [" ;
491+ bool is_last_dim = (dim == shape.size () - 1 );
492+ for (dim_t i = 0 ; i < shape[dim]; ++i) {
493+ if (i > 0 ) {
494+ os << " , " ;
495+ if (!is_last_dim) {
496+ os << " \n " << std::string (indent, ' ' );
497+ }
498+ }
499+
500+ if (i == PRINT_MAX_VALUES / 2 && shape[dim] > PRINT_MAX_VALUES) {
501+ os << " ..." ;
502+ i = shape[dim] - PRINT_MAX_VALUES / 2 - 1 ; // Skip to the last part
503+ } else {
504+ if (is_last_dim) {
505+ os << +data[offset + i];
506+ } else {
507+ print_tensor (os, data, shape, dim + 1 , offset + i * shape[dim + 1 ], indent);
508+ }
509+ }
510+ }
511+ os << " ]" ;
512+ }
513+
482514#define DECLARE_IMPL (T ) \
483515 template \
484516 StorageView::StorageView (Shape shape, T init, Device device); \
0 commit comments