@@ -16,7 +16,7 @@ use burn::tensor::Tensor;
1616#[ allow( unused_imports) ]
1717use num_traits:: Float ;
1818
19- #[ derive( Config ) ]
19+ #[ derive( Config , Debug ) ]
2020pub struct GeGluConfig {
2121 /// The size of the input features.
2222 d_input : usize ,
@@ -44,7 +44,7 @@ impl<B: Backend> GeGlu<B> {
4444 }
4545}
4646
47- #[ derive( Config ) ]
47+ #[ derive( Config , Debug ) ]
4848pub struct FeedForwardConfig {
4949 /// The size of the input features.
5050 pub d_input : usize ,
@@ -90,7 +90,7 @@ impl<B: Backend> FeedForward<B> {
9090 }
9191}
9292
93- #[ derive( Config ) ]
93+ #[ derive( Config , Debug ) ]
9494pub struct CrossAttentionConfig {
9595 /// The number of channels in the query.
9696 d_query : usize ,
@@ -232,7 +232,7 @@ impl<B: Backend> CrossAttention<B> {
232232 }
233233}
234234
235- #[ derive( Config ) ]
235+ #[ derive( Config , Debug ) ]
236236pub struct BasicTransformerBlockConfig {
237237 d_model : usize ,
238238 d_context : Option < usize > ,
@@ -488,13 +488,13 @@ mod tests {
488488 use super :: * ;
489489 use crate :: TestBackend ;
490490 use burn:: module:: { Param , ParamId } ;
491- use burn:: tensor:: { Data , Shape } ;
491+ use burn:: tensor:: { Shape , TensorData , Tolerance } ;
492492
493493 #[ test]
494494 fn test_geglu_tensor_shape_3 ( ) {
495495 let device = Default :: default ( ) ;
496496 let weight = Tensor :: from_data (
497- Data :: from ( [
497+ TensorData :: from ( [
498498 [
499499 0.1221 , 2.0378 , -0.1171 , 1.3004 , -0.9630 , -0.3108 , -1.3376 , -1.0593 ,
500500 ] ,
@@ -505,7 +505,7 @@ mod tests {
505505 & device,
506506 ) ;
507507 let bias = Tensor :: from_data (
508- Data :: from ( [
508+ TensorData :: from ( [
509509 0.2867778149426027 ,
510510 0.6646517317105776 ,
511511 0.023946332404821136 ,
@@ -526,7 +526,7 @@ mod tests {
526526 } ;
527527
528528 let tensor: Tensor < TestBackend , 3 > = Tensor :: from_data (
529- Data :: from ( [
529+ TensorData :: from ( [
530530 [ [ 1. , 2. ] , [ 3. , 4. ] , [ 5. , 6. ] ] ,
531531 [ [ 7. , 8. ] , [ 9. , 10. ] , [ 11. , 12. ] ] ,
532532 ] ) ,
@@ -535,8 +535,8 @@ mod tests {
535535
536536 let output = geglu. forward ( tensor) ;
537537 assert_eq ! ( output. shape( ) , Shape :: from( [ 2 , 3 , 4 ] ) ) ;
538- output. to_data ( ) . assert_approx_eq (
539- & Data :: from ( [
538+ output. into_data ( ) . assert_approx_eq :: < f32 > (
539+ & TensorData :: from ( [
540540 [
541541 [ 4.2632e0 , -1.7927e-1 , -2.3216e-1 , -3.7916e-2 ] ,
542542 [ 1.3460e1 , -2.9266e-1 , -2.1707e-4 , -4.5595e-2 ] ,
@@ -548,22 +548,22 @@ mod tests {
548548 [ 1.0119e2 , -2.1943e-5 , -0.0000e0 , -0.0000e0 ] ,
549549 ] ,
550550 ] ) ,
551- 2 ,
551+ Tolerance :: rel_abs ( 1e-2 , 1e-2 ) ,
552552 ) ;
553553 }
554554
555555 #[ test]
556556 fn test_geglu_tensor_shape_2 ( ) {
557557 let device = Default :: default ( ) ;
558558 let weight = Tensor :: from_data (
559- Data :: from ( [
559+ TensorData :: from ( [
560560 [ 0.6054 , 1.9322 , 0.1445 , 1.3004 , -0.6853 , -0.8947 ] ,
561561 [ -0.3678 , 0.4081 , -1.9001 , -1.5843 , -0.9399 , 0.1018 ] ,
562562 ] ) ,
563563 & device,
564564 ) ;
565565 let bias = Tensor :: from_data (
566- Data :: from ( [
566+ TensorData :: from ( [
567567 0.3237631905393836 ,
568568 0.22052049807936902 ,
569569 -0.3196353346822061 ,
@@ -582,17 +582,17 @@ mod tests {
582582 } ;
583583
584584 let tensor: Tensor < TestBackend , 2 > =
585- Tensor :: from_data ( Data :: from ( [ [ 1. , 2. ] , [ 3. , 4. ] , [ 5. , 6. ] ] ) , & device) ;
585+ Tensor :: from_data ( TensorData :: from ( [ [ 1. , 2. ] , [ 3. , 4. ] , [ 5. , 6. ] ] ) , & device) ;
586586
587587 let output = geglu. forward ( tensor) ;
588588 assert_eq ! ( output. shape( ) , Shape :: from( [ 3 , 3 ] ) ) ;
589- output. to_data ( ) . assert_approx_eq (
590- & Data :: from ( [
589+ output. into_data ( ) . assert_approx_eq :: < f32 > (
590+ & TensorData :: from ( [
591591 [ -2.4192e-5 , -3.3057e-2 , 2.8535e-1 ] ,
592592 [ -0.0000e0 , -2.0983e-7 , 5.2465e-1 ] ,
593593 [ -0.0000e0 , -0.0000e0 , 1.2599e-2 ] ,
594594 ] ) ,
595- 1 ,
595+ Tolerance :: rel_abs ( 1e-1 , 1e-1 ) ,
596596 ) ;
597597 }
598598
@@ -601,7 +601,7 @@ mod tests {
601601 let device = Default :: default ( ) ;
602602 // create tensor of size [2, 4, 2]
603603 let query: Tensor < TestBackend , 3 > = Tensor :: from_data (
604- Data :: from ( [
604+ TensorData :: from ( [
605605 [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] , [ 5.0 , 6.0 ] , [ 7.0 , 8.0 ] ] ,
606606 [ [ 9.0 , 10.0 ] , [ 11.0 , 12.0 ] , [ 13.0 , 14.0 ] , [ 15.0 , 16.0 ] ] ,
607607 [ [ 17.0 , 18.0 ] , [ 19.0 , 20.0 ] , [ 21.0 , 22.0 ] , [ 23.0 , 24.0 ] ] ,
@@ -610,7 +610,7 @@ mod tests {
610610 & device,
611611 ) ;
612612 let key: Tensor < TestBackend , 3 > = Tensor :: from_data (
613- Data :: from ( [
613+ TensorData :: from ( [
614614 [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] , [ 5.0 , 6.0 ] , [ 7.0 , 8.0 ] ] ,
615615 [ [ 9.0 , 10.0 ] , [ 11.0 , 12.0 ] , [ 13.0 , 14.0 ] , [ 15.0 , 16.0 ] ] ,
616616 [ [ 17.0 , 18.0 ] , [ 19.0 , 20.0 ] , [ 21.0 , 22.0 ] , [ 23.0 , 24.0 ] ] ,
@@ -619,7 +619,7 @@ mod tests {
619619 & device,
620620 ) ;
621621 let value: Tensor < TestBackend , 3 > = Tensor :: from_data (
622- Data :: from ( [
622+ TensorData :: from ( [
623623 [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] , [ 5.0 , 6.0 ] , [ 7.0 , 8.0 ] ] ,
624624 [ [ 9.0 , 10.0 ] , [ 11.0 , 12.0 ] , [ 13.0 , 14.0 ] , [ 15.0 , 16.0 ] ] ,
625625 [ [ 17.0 , 18.0 ] , [ 19.0 , 20.0 ] , [ 21.0 , 22.0 ] , [ 23.0 , 24.0 ] ] ,
@@ -637,8 +637,8 @@ mod tests {
637637 let output = cross_attention. sliced_attention ( query, key, value, 2 ) ;
638638
639639 assert_eq ! ( output. shape( ) , Shape :: from( [ 2 , 4 , 4 ] ) ) ;
640- output. into_data ( ) . assert_approx_eq (
641- & Data :: from ( [
640+ output. into_data ( ) . assert_approx_eq :: < f32 > (
641+ & TensorData :: from ( [
642642 [
643643 [ 5.9201 , 6.9201 , 14.9951 , 15.9951 ] ,
644644 [ 6.7557 , 7.7557 , 14.9986 , 15.9986 ] ,
@@ -652,7 +652,7 @@ mod tests {
652652 [ 23.0000 , 24.0000 , 31.0000 , 32.0000 ] ,
653653 ] ,
654654 ] ) ,
655- 3 ,
655+ Tolerance :: rel_abs ( 1e-3 , 1e-3 ) ,
656656 )
657657 }
658658}
0 commit comments