@@ -411,40 +411,199 @@ using ScalarType = exec_aten::ScalarType;
 // Utility functions for checking tensor attributes
 //
 
+inline bool tensor_can_cast_to(
+    exec_aten::Tensor a,
+    exec_aten::ScalarType dtype) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::canCast(a.scalar_type(), dtype),
+      "Tensor of dtype %s cannot cast to dtype %s",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(dtype));
+
+  return true;
+}
+
+inline bool tensor_is_bool_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.scalar_type() == exec_aten::ScalarType::Bool,
+      "Expected to find bool type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_integral_type(
+    exec_aten::Tensor t,
+    bool includeBool = false) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isIntegralType(t.scalar_type(), includeBool),
+      "Expected to find an integral type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_floating_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isFloatingType(t.scalar_type()),
+      "Expected to find a floating type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_complex_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isComplexType(t.scalar_type()),
+      "Expected to find a complex type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_bits_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isBitsType(t.scalar_type()),
+      "Expected to find a bits type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
 inline bool tensors_have_same_dtype(exec_aten::Tensor a, exec_aten::Tensor b) {
-  return a.scalar_type() == b.scalar_type();
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.scalar_type() == b.scalar_type(),
+      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s}",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(b.scalar_type()));
+  return true;
 }
 
 inline bool tensors_have_same_dtype(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return a.scalar_type() == b.scalar_type() &&
-      b.scalar_type() == c.scalar_type();
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(),
+      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s, %s}",
+      torch::executor::toString(a.scalar_type()),
+      torch::executor::toString(b.scalar_type()),
+      torch::executor::toString(c.scalar_type()));
+  return true;
+}
+
+inline bool tensor_is_rank(exec_aten::Tensor t, size_t rank) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.dim() == rank,
+      "Expected tensor.dim() to be %zu, but got %zu",
+      static_cast<size_t>(rank),
+      static_cast<size_t>(t.dim()));
+
+  return true;
+}
+
+inline bool tensor_has_dim(exec_aten::Tensor t, int64_t d) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      d > 0 ? d < t.dim() : t.dim() + d >= 0,
+      "%zu-dim tensor does not have dim at index %zu",
+      static_cast<size_t>(t.dim()),
+      static_cast<size_t>(d));
+
+  return true;
+}
+
+inline bool tensors_have_same_size_at_dims(
+    exec_aten::Tensor a,
+    size_t dim_a,
+    exec_aten::Tensor b,
+    size_t dim_b) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      dim_a < a.dim(),
+      "Cannot retrieve dim %zu from tensor with dim %zu",
+      static_cast<size_t>(dim_a),
+      static_cast<size_t>(a.dim()));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      dim_b < b.dim(),
+      "Cannot retrieve dim %zu from tensor with dim %zu",
+      static_cast<size_t>(dim_b),
+      static_cast<size_t>(b.dim()));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      a.size(dim_a) == b.size(dim_b),
+      ET_TENSOR_CHECK_PREFIX__
+      ": a.size(%zu) = %zu does not match b.size(%zu) = %zu",
+      static_cast<size_t>(dim_a),
+      static_cast<size_t>(a.size(dim_a)),
+      static_cast<size_t>(dim_b),
+      static_cast<size_t>(b.size(dim_b)));
+
+  return true;
 }
 
 inline bool tensors_have_same_shape(exec_aten::Tensor a, exec_aten::Tensor b) {
-  if (a.numel() != b.numel()) {
-    return false;
-  }
-  if (a.numel() == 1) {
+  if (a.numel() == 1 && b.numel() == 1) {
     // PyTorch operators treat all scalar tensors as the same shape even if
     // they have different dims.
     return true;
   }
-  // Does a length comparison (ensuring dims are equal) and element-by-element
-  // comparison (ensuring sizes are equal).
-  if (a.sizes() != b.sizes()) {
+  if (!(a.sizes() == b.sizes() && a.numel() == b.numel())) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu), dim=(%zu, %zu)",
+        static_cast<size_t>(a.numel()),
+        static_cast<size_t>(b.numel()),
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()));
+    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "size(%zu): (%zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.size(d)),
+          static_cast<size_t>(b.size(d)));
+    }
+
     return false;
   }
+
   return true;
 }
 
 inline bool tensors_have_same_shape(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return tensors_have_same_shape(a, b) && tensors_have_same_shape(b, c);
+  if (a.numel() == 1 && b.numel() == 1 && c.numel() == 1) {
+    // PyTorch operators treat all scalar tensors as the same shape even if
+    // they have different dims.
+    return true;
+  }
+  bool cond1 = (a.sizes() == b.sizes()) && (a.numel() == b.numel());
+  bool cond2 = (b.sizes() == c.sizes()) && (b.numel() == c.numel());
+
+  if (!(cond1 && cond2)) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu, %zu), dim=(%zu, %zu, %zu)",
+        static_cast<size_t>(a.numel()),
+        static_cast<size_t>(b.numel()),
+        static_cast<size_t>(c.numel()),
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()),
+        static_cast<size_t>(c.dim()));
+    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "size(%zu): (%zu, %zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.size(d)),
+          static_cast<size_t>(b.size(d)),
+          static_cast<size_t>(c.size(d)));
+    }
+
+    return false;
+  }
+
+  return true;
 }
 
 inline bool tensors_have_same_shape_and_dtype(
@@ -463,14 +622,50 @@ inline bool tensors_have_same_shape_and_dtype(
 inline bool tensors_have_same_strides(
     exec_aten::Tensor a,
     exec_aten::Tensor b) {
-  return a.strides() == b.strides();
+  if (a.strides() != b.strides()) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()));
+    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "stride(%zu): (%zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.strides()[d]),
+          static_cast<size_t>(b.strides()[d]));
+    }
+
+    return false;
+  }
+  return true;
 }
 
 inline bool tensors_have_same_strides(
     exec_aten::Tensor a,
     exec_aten::Tensor b,
     exec_aten::Tensor c) {
-  return a.strides() == b.strides() && b.strides() == c.strides();
+  if (!(a.strides() == b.strides() && b.strides() == c.strides())) {
+    ET_LOG(
+        Error,
+        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu, %zu)",
+        static_cast<size_t>(a.dim()),
+        static_cast<size_t>(b.dim()),
+        static_cast<size_t>(c.dim()));
+    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
+      ET_LOG(
+          Error,
+          "stride(%zu): (%zu, %zu, %zu)",
+          static_cast<size_t>(d),
+          static_cast<size_t>(a.strides()[d]),
+          static_cast<size_t>(b.strides()[d]),
+          static_cast<size_t>(c.strides()[d]));
+    }
+
+    return false;
+  }
+  return true;
 }
 
 inline bool tensor_is_contiguous(exec_aten::Tensor t) {
@@ -480,13 +675,21 @@ inline bool tensor_is_contiguous(exec_aten::Tensor t) {
   if (strides.size() == 0) {
     return true;
   }
-  if (strides[strides.size() - 1] != 1) {
-    return false;
-  }
-  for (auto i = strides.size() - 1; i > 0; --i) {
-    if (strides[i - 1] != strides[i] * sizes[i]) {
-      return false;
-    }
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      strides[strides.size() - 1] == 1,
+      "Tensor is not contiguous; the stride of the last dimension must be 1, "
+      "but got %zu",
+      static_cast<size_t>(strides[strides.size() - 1]));
+  for (int i = strides.size() - 1; i > 0; --i) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        strides[i - 1] == strides[i] * sizes[i],
+        "Tensor is not contiguous; the stride of dim %zu should be equal to "
+        "strides[%zu] * sizes[%zu] = %zu, but found %zu",
+        static_cast<size_t>(i - 1),
+        static_cast<size_t>(i),
+        static_cast<size_t>(i),
+        static_cast<size_t>(strides[i] * sizes[i]),
+        static_cast<size_t>(strides[i - 1]));
   }
   return true;
 }
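
For context, a minimal caller sketch (not part of the diff) showing how an operator kernel might use these check helpers before touching tensor data; each helper logs a descriptive message and returns false, so the kernel can bail out cleanly instead of aborting. The function name check_add_args and its arguments are hypothetical, and the helpers are assumed to be reachable from the kernel because they live in this shared header.

// Hypothetical kernel-side validation using the helpers added above.
bool check_add_args(
    exec_aten::Tensor a,
    exec_aten::Tensor b,
    exec_aten::Tensor out) {
  // Inputs must agree in shape and dtype; the helper logs details on mismatch.
  if (!tensors_have_same_shape_and_dtype(a, b)) {
    return false;
  }
  // The input dtype must be castable to the output dtype.
  if (!tensor_can_cast_to(a, out.scalar_type())) {
    return false;
  }
  // The output must already have the expected shape.
  return tensors_have_same_shape(a, out);
}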