@@ -509,11 +509,11 @@ TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
   EXPECT_EQ(sizes_ptr[2], 3);
 }
 
-// Test different data types (only float32 is currently supported)
+// Test different data types (currently bf16, fp32, and int32 are supported)
 TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
   std::vector<int64_t> sizes = {2, 3};
 
-  // Test float32 (dtype 6) - currently the only supported type
+  // Test float32 (dtype 6) - one of the supported types
   Tensor* tensor_float32;
   AOTITorchError error = aoti_torch_empty_strided(
       sizes.size(),
@@ -527,7 +527,7 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
   EXPECT_EQ(error, Error::Ok);
   EXPECT_NE(tensor_float32, nullptr);
 
-  // Test unsupported data types should return error
+  // Test int32 (dtype 3) - one of the supported types
   Tensor* tensor_int32;
   error = aoti_torch_empty_strided(
       sizes.size(),
@@ -538,7 +538,8 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
       0, // device index
       &tensor_int32);
 
-  EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_int32, nullptr);
 
   // Test float64 - an unsupported data type
   Tensor* tensor_float64;
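
For context: the numeric dtype codes referenced in the comments above (6 for
float32, 3 for int32) match PyTorch's ScalarType numbering (Int = 3, Float = 6,
BFloat16 = 15). A minimal sketch of what the SupportedDTypes enum used by these
tests could look like under that assumption; the actual definition lives
elsewhere in the repo and may differ:

    // Hypothetical sketch, assuming PyTorch ScalarType codes.
    enum class SupportedDTypes : int32_t {
      INT32 = 3,    // c10::ScalarType::Int
      FLOAT32 = 6,  // c10::ScalarType::Float
      BFLOAT16 = 15 // c10::ScalarType::BFloat16
    };
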
@@ -586,3 +587,105 @@ TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
   EXPECT_EQ(tensor_5d->size(3), 4);
   EXPECT_EQ(tensor_5d->size(4), 5);
 }
+
+// Test non-contiguous tensor creation - transpose-like layout
+TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) {
+  // Create a tensor with transpose-like strides (column-major).
+  // For a 3x4 tensor in column-major order, the strides should be [1, 3]:
+  // each step along dimension 0 moves 1 element; each step along dimension 1 moves 3.
+  std::vector<int64_t> sizes = {3, 4};
+  std::vector<int64_t> strides = {1, 3}; // Column-major (non-contiguous)
+
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify tensor properties
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 3);
+  EXPECT_EQ(tensor->size(1), 4);
+
+  // Verify the strides are what we specified
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0
+  EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1
+
+  // Verify that memory was allocated correctly for the non-contiguous layout.
+  // Storage size = stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) + 1
+  //   = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 2 + 9 + 1 = 12 elements.
+  // Total bytes = 12 * 4 = 48 (for float32).
+  EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for the logical shape
+
+  // The tensor should be accessible and writable
+  void* data_ptr = tensor->mutable_data_ptr();
+  EXPECT_NE(data_ptr, nullptr);
+
+  // Verify we can use CUDA to write to the memory
+  std::vector<float> test_data(12, 1.0f);
+  cudaError_t cuda_err = cudaMemcpy(
+      data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
+  EXPECT_EQ(cuda_err, cudaSuccess);
+}
+
+// Test non-contiguous tensor creation - expanded/broadcasted stride pattern
+TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) {
+  // Create a tensor with expanded strides (simulating broadcasting):
+  // a 2x3x4 tensor where the first dimension has stride 0 (expanded).
+  // This creates a tensor where the first dimension is "broadcasted".
+  std::vector<int64_t> sizes = {2, 3, 4};
+  std::vector<int64_t> strides = {0, 4, 1}; // First dimension has stride 0
+
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+  EXPECT_EQ(tensor->size(2), 4);
+
+  // Verify the strides are what we specified
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride
+  EXPECT_EQ(strides_ptr[1], 4);
+  EXPECT_EQ(strides_ptr[2], 1);
+
+  // Verify that memory was allocated correctly for this non-contiguous layout.
+  // Storage size = stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1)
+  //   + stride[2] * (size[2] - 1) + 1
+  //   = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 - 1) + 1 = 0 + 8 + 3 + 1 = 12 elements.
+  // Note: numel() returns the logical number of elements (2*3*4=24), not storage size.
+  EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24
+
+  // The tensor should be accessible and writable
+  void* data_ptr = tensor->mutable_data_ptr();
+  EXPECT_NE(data_ptr, nullptr);
+
+  // Verify we can use CUDA to write to the allocated memory.
+  // Only 12 elements (the storage size) need to be written, not 24.
+  std::vector<float> test_data(12, 2.0f);
+  cudaError_t cuda_err = cudaMemcpy(
+      data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
+  EXPECT_EQ(cuda_err, cudaSuccess);
+}
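
The storage-size arithmetic in the comments above follows the standard rule for
strided tensors: the backing allocation must hold 1 + sum_i strides[i] *
(sizes[i] - 1) elements (0 if any dimension has size 0), which can be far
smaller than numel() when strides overlap, as in the stride-0 broadcast case. A
minimal self-contained sketch of that computation; the helper name is
hypothetical and not part of this diff:

    #include <cstdint>
    #include <vector>

    // Hypothetical helper: smallest number of elements a buffer must hold to
    // back a tensor with the given sizes and strides.
    int64_t min_storage_numel(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& strides) {
      int64_t last_offset = 0; // offset of the farthest addressable element
      for (size_t i = 0; i < sizes.size(); ++i) {
        if (sizes[i] == 0) {
          return 0; // a zero-sized dimension means no storage is needed
        }
        last_offset += strides[i] * (sizes[i] - 1);
      }
      return last_offset + 1;
    }

    // min_storage_numel({3, 4}, {1, 3})       == 12 (column-major test above)
    // min_storage_numel({2, 3, 4}, {0, 4, 1}) == 12, while numel() is 24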