@@ -11,6 +11,7 @@ package ggml
1111import  "C" 
1212
1313import  (
14+ 	"cmp" 
1415	"context" 
1516	"errors" 
1617	"fmt" 
@@ -1410,43 +1411,75 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
14101411
14111412func  (t  * Tensor ) RoPE (ctx  ml.Context , positions  ml.Tensor , ropeDim  int , ropeBase , ropeScale  float32 , options  ... func (* rope.Options )) ml.Tensor  {
14121413	// Default options 
1413- 	opts  :=  rope.Options {
1414- 		Factors :               & Tensor {},
1415- 		OriginalContextLength : 131072 ,
1416- 		ExtrapolationFactor :   0. ,
1417- 		AttentionFactor :       1. ,
1418- 		BetaFast :              32. ,
1419- 		BetaSlow :              1. ,
1420- 	}
1414+ 	opts  :=  rope.Options {Factors : & Tensor {}}
14211415
14221416	// Apply any provided options 
14231417	for  _ , option  :=  range  options  {
14241418		option (& opts )
14251419	}
14261420
1421+ 	factors  :=  opts .Factors 
1422+ 	if  factors  ==  nil  {
1423+ 		factors  =  & Tensor {}
1424+ 	}
1425+ 
1426+ 	tensorFactors , ok  :=  factors .(* Tensor )
1427+ 	if  ! ok  {
1428+ 		panic ("ggml: unsupported tensor type for RoPE factors" )
1429+ 	}
1430+ 
14271431	dequant  :=  t .t 
14281432	if  C .ggml_is_quantized (t .t ._type ) {
14291433		dequant  =  C .ggml_cast (ctx .(* Context ).ctx , t .t , C .GGML_TYPE_F32 )
14301434	}
14311435
1432- 	return  & Tensor {
1433- 		b : t .b ,
1434- 		t : C .ggml_rope_ext (
1436+ 	originalContextLength  :=  cmp .Or (opts .YaRN .OriginalContextLength , 128 << 10 )
1437+ 	attentionFactor  :=  cmp .Or (opts .YaRN .AttentionFactor , float32 (1 ))
1438+ 	betaFast  :=  cmp .Or (opts .YaRN .BetaFast , float32 (32 ))
1439+ 	betaSlow  :=  cmp .Or (opts .YaRN .BetaSlow , float32 (1 ))
1440+ 
1441+ 	var  sections  * C.int 
1442+ 	if  len (opts .MRoPE .Sections ) >  0  {
1443+ 		sections  =  (* C .int )(unsafe .Pointer (& opts .MRoPE .Sections [0 ]))
1444+ 	}
1445+ 
1446+ 	var  tt  * C.struct_ggml_tensor 
1447+ 	if  opts .Type & 0b1000  !=  0  {
1448+ 		tt  =  C .ggml_rope_multi (
14351449			ctx .(* Context ).ctx ,
14361450			dequant ,
14371451			positions .(* Tensor ).t ,
1438- 			opts . Factors .( * Tensor ) .t ,
1452+ 			tensorFactors .t ,
14391453			C .int (ropeDim ),
1454+ 			sections ,
14401455			C .int (opts .Type ),
1441- 			C .int (opts . OriginalContextLength ),
1456+ 			C .int (originalContextLength ),
14421457			C .float (ropeBase ),
14431458			C .float (ropeScale ),
1444- 			C .float (opts .ExtrapolationFactor ),
1445- 			C .float (opts .AttentionFactor ),
1446- 			C .float (opts .BetaFast ),
1447- 			C .float (opts .BetaSlow ),
1448- 		),
1459+ 			C .float (opts .YaRN .ExtrapolationFactor ),
1460+ 			C .float (attentionFactor ),
1461+ 			C .float (betaFast ),
1462+ 			C .float (betaSlow ),
1463+ 		)
1464+ 	} else  {
1465+ 		tt  =  C .ggml_rope_ext (
1466+ 			ctx .(* Context ).ctx ,
1467+ 			dequant ,
1468+ 			positions .(* Tensor ).t ,
1469+ 			tensorFactors .t ,
1470+ 			C .int (ropeDim ),
1471+ 			C .int (opts .Type ),
1472+ 			C .int (originalContextLength ),
1473+ 			C .float (ropeBase ),
1474+ 			C .float (ropeScale ),
1475+ 			C .float (opts .YaRN .ExtrapolationFactor ),
1476+ 			C .float (attentionFactor ),
1477+ 			C .float (betaFast ),
1478+ 			C .float (betaSlow ),
1479+ 		)
14491480	}
1481+ 
1482+ 	return  & Tensor {b : t .b , t : tt }
14501483}
14511484
14521485func  (t  * Tensor ) IM2Col (ctx  ml.Context , t2  ml.Tensor , s0 , s1 , p0 , p1 , d0 , d1  int ) ml.Tensor  {
@@ -1509,6 +1542,27 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
15091542	}
15101543}
15111544
1545+ func  (t  * Tensor ) Conv3D (ctx  ml.Context , t2  ml.Tensor , ic , s0 , s1 , s2 , p0 , p1 , p2 , d0 , d1 , d2  int ) ml.Tensor  {
1546+ 	return  & Tensor {
1547+ 		b : t .b ,
1548+ 		t : C .ggml_conv_3d (
1549+ 			ctx .(* Context ).ctx ,
1550+ 			t .t ,
1551+ 			t2 .(* Tensor ).t ,
1552+ 			C .int64_t (ic ),
1553+ 			C .int (s0 ),
1554+ 			C .int (s1 ),
1555+ 			C .int (s2 ),
1556+ 			C .int (p0 ),
1557+ 			C .int (p1 ),
1558+ 			C .int (p2 ),
1559+ 			C .int (d0 ),
1560+ 			C .int (d1 ),
1561+ 			C .int (d2 ),
1562+ 		),
1563+ 	}
1564+ }
1565+ 
15121566func  (t  * Tensor ) AvgPool2D (ctx  ml.Context , k , s  int , p  float32 ) ml.Tensor  {
15131567	return  & Tensor {
15141568		b : t .b ,
0 commit comments