Skip to content

Commit a1dac6c

Browse files
committed
Change the API to mutate passed in keyValues (as "consume").
1 parent d7df218 commit a1dac6c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
55

66
git_repository(
77
name = "ccv",
8-
commit = "479c846ca1aff4a313ef0d632f03e05f800f13ea",
8+
commit = "6fd6af5eb258e8bbf22daa02f81a7f900cffe0be",
99
remote = "https://github.com/liuliu/ccv.git",
10-
shallow_since = "1746828831 -0400",
10+
shallow_since = "1746841166 -0400",
1111
)
1212

1313
load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")

nnc/Model.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,23 +200,29 @@ public class Model: AnyModel {
200200
/**
201201
* Move the backing tensor into the parameters. Making the original one invalid afterwards.
202202
*/
203-
public func attach(consuming keyValues: [(key: String, value: AnyTensor)]) {
203+
public func attach(consuming keyValues: inout [(key: String, value: AnyTensor)]) {
204204
let keys = keyValues.map { $0.key.utf8CString }
205205
let values = keyValues.map(\.value)
206-
withExtendedLifetime(keys) { keys in
206+
// Once we attached, we remove all the keyValues that are consumed.
207+
keyValues = withExtendedLifetime(keys) { keys in
207208
withExtendedLifetime(values) { values in
208209
var cTensors: [UnsafeMutablePointer<ccv_nnc_tensor_t>?] = values.map { $0.cTensor }
209210
var cKeys: [UnsafeMutablePointer<CChar>?] = keys.map {
210211
$0.withUnsafeBufferPointer { $0.baseAddress.map { UnsafeMutablePointer(mutating: $0) } }
211212
}
212-
cKeys.withUnsafeMutableBufferPointer { cKeys in
213+
return cKeys.withUnsafeMutableBufferPointer { cKeys in
213214
cTensors.withUnsafeMutableBufferPointer { cTensors in
214215
ccv_cnnp_model_set_parameters_from_key_values(
215216
model!.cModel, cKeys.baseAddress, cTensors.baseAddress, Int32(keyValues.count), 1)
217+
var newKeyValues = [(key: String, value: AnyTensor)]()
216218
for (i, value) in values.enumerated() {
217-
guard cTensors[i] == nil else { continue }
219+
guard cTensors[i] == nil else {
220+
newKeyValues.append(keyValues[i])
221+
continue
222+
}
218223
value.consume()
219224
}
225+
return newKeyValues
220226
}
221227
}
222228
}

0 commit comments

Comments
 (0)