@@ -270,15 +270,16 @@ extension ModelIOConvertible {
270270public final class Pad : Model {
271271 public enum Mode {
272272 case zero
273- case replication
274- public static let clampToEdge : Mode = . replication
273+ case replicate
274+ case reflect
275+ public static let clampToEdge : Mode = . replicate
275276 }
276277 required init ( _ model: OpaquePointer ) {
277278 super. init ( model)
278279 }
279280
280281 public init (
281- _ mode: Mode , begin: TensorShape , end: TensorShape , name: String = " "
282+ _ mode: Mode , begin: TensorShape = [ ] , end: TensorShape = [ ] , name: String = " "
282283 ) {
283284 var begin = begin. dims
284285 var end = end. dims
@@ -290,8 +291,10 @@ public final class Pad: Model {
290291 switch mode {
291292 case . zero:
292293 type = Int32 ( CCV_NNC_PAD_ZERO)
293- case . replication :
294+ case . replicate :
294295 type = Int32 ( CCV_NNC_PAD_REPLICATE)
296+ case . reflect:
297+ type = Int32 ( CCV_NNC_PAD_REFLECT)
295298 }
296299 return ccv_cnnp_pad ( type, begin, end, name) !
297300 }
@@ -308,7 +311,7 @@ extension ModelIOConvertible {
308311 * - begin: The beginning pad for each dimension.
309312 * - end: The end pad for each dimension.
310313 */
311- public func padded( _ mode: Pad . Mode , begin: TensorShape , end: TensorShape ) -> Model . IO {
314+ public func padded( _ mode: Pad . Mode , begin: TensorShape = [ ] , end: TensorShape = [ ] ) -> Model . IO {
312315 return Pad ( mode, begin: begin, end: end) ( self )
313316 }
314317}
@@ -715,11 +718,15 @@ public final class RMSNorm: Model {
715718 super. init ( model)
716719 }
717720
718- public init ( epsilon: Float , axis: [ Int ] , trainable: Bool ? = nil , name: String = " " ) {
721+ public init (
722+ epsilon: Float , axis: [ Int ] , elementwiseAffine: Bool = true , trainable: Bool ? = nil ,
723+ name: String = " "
724+ ) {
719725 let axis32 : [ Int32 ] = axis. map { Int32 ( $0) }
720726 super. init (
721727 ccv_cnnp_rmsnorm (
722- epsilon, axis32, Int32 ( axis. count) , trainable == true ? 1 : ( trainable == false ? 0 : - 1 ) ,
728+ epsilon, axis32, Int32 ( axis. count) , elementwiseAffine ? 1 : 0 ,
729+ trainable == true ? 1 : ( trainable == false ? 0 : - 1 ) ,
723730 name) )
724731 }
725732
0 commit comments