Skip to content

Commit 7e2cf80

Browse files
committed
Update to support elementwiseAffine for RMS, reflect for pad.
1 parent 8c5e290 commit 7e2cf80

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
55

66
git_repository(
77
name = "ccv",
8-
commit = "8f80a55f494c63d5ff333aff6969395ca156f04e",
8+
commit = "12b2355e882a8de1e15c1520ed16866694948ab1",
99
remote = "https://github.com/liuliu/ccv.git",
10-
shallow_since = "1766523761 -0500",
10+
shallow_since = "1768411247 -0500",
1111
)
1212

1313
load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")

deps.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def s4nnc_deps():
1717
git_repository,
1818
name = "ccv",
1919
remote = "https://github.com/liuliu/ccv.git",
20-
commit = "8f80a55f494c63d5ff333aff6969395ca156f04e",
21-
shallow_since = "1766523761 -0500",
20+
commit = "12b2355e882a8de1e15c1520ed16866694948ab1",
21+
shallow_since = "1768411247 -0500",
2222
)
2323

2424
_maybe(

nnc/ModelAddons.swift

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,16 @@ extension ModelIOConvertible {
270270
public final class Pad: Model {
271271
public enum Mode {
272272
case zero
273-
case replication
274-
public static let clampToEdge: Mode = .replication
273+
case replicate
274+
case reflect
275+
public static let clampToEdge: Mode = .replicate
275276
}
276277
required init(_ model: OpaquePointer) {
277278
super.init(model)
278279
}
279280

280281
public init(
281-
_ mode: Mode, begin: TensorShape, end: TensorShape, name: String = ""
282+
_ mode: Mode, begin: TensorShape = [], end: TensorShape = [], name: String = ""
282283
) {
283284
var begin = begin.dims
284285
var end = end.dims
@@ -290,8 +291,10 @@ public final class Pad: Model {
290291
switch mode {
291292
case .zero:
292293
type = Int32(CCV_NNC_PAD_ZERO)
293-
case .replication:
294+
case .replicate:
294295
type = Int32(CCV_NNC_PAD_REPLICATE)
296+
case .reflect:
297+
type = Int32(CCV_NNC_PAD_REFLECT)
295298
}
296299
return ccv_cnnp_pad(type, begin, end, name)!
297300
}
@@ -308,7 +311,7 @@ extension ModelIOConvertible {
308311
* - begin: The beginning pad for each dimension.
309312
* - end: The end pad for each dimension.
310313
*/
311-
public func padded(_ mode: Pad.Mode, begin: TensorShape, end: TensorShape) -> Model.IO {
314+
public func padded(_ mode: Pad.Mode, begin: TensorShape = [], end: TensorShape = []) -> Model.IO {
312315
return Pad(mode, begin: begin, end: end)(self)
313316
}
314317
}
@@ -715,11 +718,15 @@ public final class RMSNorm: Model {
715718
super.init(model)
716719
}
717720

718-
public init(epsilon: Float, axis: [Int], trainable: Bool? = nil, name: String = "") {
721+
public init(
722+
epsilon: Float, axis: [Int], elementwiseAffine: Bool = true, trainable: Bool? = nil,
723+
name: String = ""
724+
) {
719725
let axis32: [Int32] = axis.map { Int32($0) }
720726
super.init(
721727
ccv_cnnp_rmsnorm(
722-
epsilon, axis32, Int32(axis.count), trainable == true ? 1 : (trainable == false ? 0 : -1),
728+
epsilon, axis32, Int32(axis.count), elementwiseAffine ? 1 : 0,
729+
trainable == true ? 1 : (trainable == false ? 0 : -1),
723730
name))
724731
}
725732

0 commit comments

Comments
 (0)