Commit 94dfc6b
Support for SDXL refiner (#227)
Support for SDXL refiner (#227)

* Initial support for SDXL refiner
* Cleanup
* Add arg for converting Unet in float32 precision
* Setup scale factor with pipeline in CLI
* Update cli arg and future warning
* Bundle refiner unet if specified
* Update script for bundled refiner
  - Also skip loading the model if check_output_correctness is missing, since the model does not require inferencing at conversion time
* Flip skip_model_load bool
* Cleanup
* Support bundled UnetRefiner
* Add separate refiner config value
  - Includes unloading the base unet when swapping to the refiner
* Update readme for SDXL refiner
* Add condition for new SDXL coreml input features
* Revert pipeline interface change, add extra logging on pipe load
* Reset model_version after refiner conversion
* Reset model_version before refiner conversion but after pipe init
* Add refiner chunking
* Ensure unets are unloaded for reduceMemory true
* Handle missing UnetRefiner.mlmodelc on pipeline load
* Prewarm refiner on load, unload on complete
* Force cpu_and_gpu for VAE until it can be fixed
* Include output dtype of np.float32 for all conversions
* Allow a custom VAE to be converted
* Revert hardcoded reduceMemory
* Fix merge
* Default chunking arg for --merge-chunks-in-pipeline-model when called from torch2coreml

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent d229861 commit 94dfc6b

File tree

10 files changed: +467 -135 lines

README.md

Lines changed: 9 additions & 3 deletions

@@ -209,10 +209,11 @@ An example `<selected-recipe-string-key>` would be `"recipe_4.50_bit_mixedpalett
 e.g.:
 
 ```bash
-python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-vae-decoder --convert-text-encoder --xl-version --model-version stabilityai/stable-diffusion-xl-base-1.0 --bundle-resources-for-swift-cli --attention-implementation ORIGINAL -o <output-dir>
+python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-vae-decoder --convert-text-encoder --xl-version --model-version stabilityai/stable-diffusion-xl-base-1.0 --refiner-version stabilityai/stable-diffusion-xl-refiner-1.0 --bundle-resources-for-swift-cli --attention-implementation ORIGINAL -o <output-dir>
 ```
 
 - `--xl-version`: Additional argument to pass to the conversion script when specifying an XL model
+- `--refiner-version`: Additional argument to pass to the conversion script when specifying an XL refiner model, required for ["Ensemble of Expert Denoisers"](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#1-ensemble-of-expert-denoisers) inference.
 - `--attention-implementation ORIGINAL` (recommended for `cpuAndGPU`)
   - Due to known float16 overflow issues in the VAE, it runs in float32 precision for now

@@ -225,7 +226,7 @@ swift run StableDiffusionSample <prompt> --resource-path <output-mlpackages-dire
 ```
 
 - Only `--compute-units cpuAndGPU` is supported for now
-- Only the `base` model is supported, `refiner` model is not yet supported
+- Only the `base` model is required; the `refiner` model is optional and is used by default if present in the resource directory
 - ControlNet for XL is not yet supported

@@ -365,6 +366,7 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi
 
 - `--model-version`: The model version name as published on the [Hugging Face Hub](https://huggingface.co/models?search=stable-diffusion)
 
+- `--refiner-version`: The refiner version name as published on the [Hugging Face Hub](https://huggingface.co/models?search=stable-diffusion). This argument is optional; if specified, the refiner unet is converted and bundled alongside the base unet.
 
 - `--bundle-resources-for-swift-cli`: Compiles all 4 models and bundles them along with necessary resources for text tokenization into `<output-mlpackages-directory>/Resources` which should be provided as input to the Swift package. This flag is not necessary for the diffusers-based Python pipeline.

@@ -439,7 +441,7 @@ This Swift package contains two products:
 
 Both of these products require the Core ML models and tokenization resources to be supplied. When specifying resources via a directory path that directory must contain the following:
 
-- `TextEncoder.mlmodelc` (text embedding model)
+- `TextEncoder.mlmodelc` or `TextEncoder2.mlmodelc` (text embedding model)
 - `Unet.mlmodelc` or `UnetChunk1.mlmodelc` & `UnetChunk2.mlmodelc` (denoising autoencoder model)
 - `VAEDecoder.mlmodelc` (image decoder model)
 - `vocab.json` (tokenizer vocabulary file)

@@ -453,6 +455,10 @@ Optionally, it may also include the safety checker model that some versions of S
 
 - `SafetyChecker.mlmodelc`
 
+Optionally, for the SDXL refiner:
+
+- `UnetRefiner.mlmodelc` (refiner unet model)
+
 Optionally, for ControlNet:
 
 - `ControlledUNet.mlmodelc` or `ControlledUnetChunk1.mlmodelc` & `ControlledUnetChunk2.mlmodelc` (enabled to receive ControlNet values)

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 145 additions & 58 deletions
Large diffs are not rendered by default.

swift/StableDiffusion/pipeline/CGImage+vImage.swift

Lines changed: 54 additions & 4 deletions

@@ -4,6 +4,7 @@
 import Foundation
 import Accelerate
 import CoreML
+import CoreGraphics
 
 @available(iOS 16.0, macOS 13.0, *)
 extension CGImage {

@@ -77,7 +78,7 @@ extension CGImage {
         else {
             throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
         }
-
+
         var sourceImageBuffer = try vImage_Buffer(cgImage: self)
 
         var mediumDestination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel)

@@ -88,7 +89,7 @@ extension CGImage {
             nil,
             vImage_Flags(kvImagePrintDiagnosticsToConsole),
             nil)
-
+
         guard let converter = converter?.takeRetainedValue() else {
             throw ShapedArrayError.vImageConverterNotInitialized
         }

@@ -99,7 +100,7 @@ extension CGImage {
         var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
         var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
         var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
-
+
         var minFloat: [Float] = Array(repeating: minValue, count: 4)
         var maxFloat: [Float] = Array(repeating: maxValue, count: 4)

@@ -125,7 +126,56 @@ extension CGImage {
         let imageData = redData + greenData + blueData
 
         let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.height, self.width])
-
+
+        return shapedArray
+    }
+
+    private func normalizePixelValues(pixel: UInt8) -> Float {
+        return (Float(pixel) / 127.5) - 1.0
+    }
+
+    public func toRGBShapedArray(minValue: Float, maxValue: Float)
+    throws -> MLShapedArray<Float32> {
+        let image = self
+        let width = image.width
+        let height = image.height
+        let alphaMaskValue: Float = minValue
+
+        guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB),
+              let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 4 * width, space: colorSpace, bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue),
+              let ptr = context.data?.bindMemory(to: UInt8.self, capacity: width * height * 4) else {
+            return []
+        }
+
+        context.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height))
+
+        var redChannel = [Float](repeating: 0, count: width * height)
+        var greenChannel = [Float](repeating: 0, count: width * height)
+        var blueChannel = [Float](repeating: 0, count: width * height)
+
+        for y in 0..<height {
+            for x in 0..<width {
+                let i = 4 * (y * width + x)
+                if ptr[i+3] == 0 {
+                    // Alpha mask for controlnets
+                    redChannel[y * width + x] = alphaMaskValue
+                    greenChannel[y * width + x] = alphaMaskValue
+                    blueChannel[y * width + x] = alphaMaskValue
+                } else {
+                    redChannel[y * width + x] = normalizePixelValues(pixel: ptr[i])
+                    greenChannel[y * width + x] = normalizePixelValues(pixel: ptr[i+1])
+                    blueChannel[y * width + x] = normalizePixelValues(pixel: ptr[i+2])
+                }
+            }
+        }
+
+        let colorShape = [1, 1, height, width]
+        let redShapedArray = MLShapedArray<Float32>(scalars: redChannel, shape: colorShape)
+        let greenShapedArray = MLShapedArray<Float32>(scalars: greenChannel, shape: colorShape)
+        let blueShapedArray = MLShapedArray<Float32>(scalars: blueChannel, shape: colorShape)
+
+        let shapedArray = MLShapedArray<Float32>(concatenating: [redShapedArray, greenShapedArray, blueShapedArray], alongAxis: 1)
+
         return shapedArray
     }
 }
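
For orientation, a minimal sketch (not part of this commit) of how the new `toRGBShapedArray` conversion might be driven, e.g. to prepare a ControlNet conditioning image. The `loadImage` helper and file path are hypothetical:

```swift
import Foundation
import CoreGraphics
import CoreML
import ImageIO

// Hypothetical helper (not in the repo): load a CGImage from disk.
func loadImage(at url: URL) -> CGImage? {
    guard let source = CGImageSourceCreateWithURL(url as CFURL, nil) else { return nil }
    return CGImageSourceCreateImageAtIndex(source, 0, nil)
}

// toRGBShapedArray yields a [1, 3, H, W] float32 array with pixels mapped to
// [-1, 1]; fully transparent pixels are filled with minValue (the alpha mask).
if let image = loadImage(at: URL(fileURLWithPath: "condition.png")) {
    let rgb = try! image.toRGBShapedArray(minValue: -1.0, maxValue: 1.0)
    print(rgb.shape) // [1, 3, height, width]
}
```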

swift/StableDiffusion/pipeline/Encoder.swift

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ public struct Encoder: ResourceManaging {
 
     var inputDescription: MLFeatureDescription {
         try! model.perform { model in
-            model.modelDescription.inputDescriptionsByName["z"]!
+            model.modelDescription.inputDescriptionsByName.first!.value
         }
     }
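
The lookup no longer assumes the VAE encoder's input is named "z", since XL conversions may declare a different input name. A hedged sketch of reading the expected input shape off the same description; this `inputShape` helper is illustrative, not repo code:

```swift
import CoreML

// Illustrative only: derive the encoder's expected input shape from whatever
// single input the compiled model declares, rather than from a hardcoded name.
extension Encoder {
    var inputShape: [Int] {
        try! model.perform { model in
            let input = model.modelDescription.inputDescriptionsByName.first!.value
            // Image/latent inputs carry a multi-array constraint holding the shape
            return input.multiArrayConstraint?.shape.map { $0.intValue } ?? []
        }
    }
}
```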

swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift

Lines changed: 20 additions & 2 deletions

@@ -20,8 +20,14 @@ public struct PipelineConfiguration: Hashable {
     public var negativePrompt: String = ""
     /// Starting image for image2image or in-painting
     public var startingImage: CGImage? = nil
-    //public var maskImage: CGImage? = nil
+    /// Fraction of inference steps to be used in `.imageToImage` pipeline mode
+    /// Must be between 0 and 1
+    /// Higher values will result in greater transformation of the `startingImage`
     public var strength: Float = 1.0
+    /// Fraction of inference steps at which to start using the refiner unet, if present, in `textToImage` mode
+    /// Must be between 0 and 1
+    /// Higher values will result in fewer refiner steps
+    public var refinerStart: Float = 0.8
     /// Number of images to generate
     public var imageCount: Int = 1
     /// Number of inference steps to perform

@@ -44,7 +50,19 @@ public struct PipelineConfiguration: Hashable {
     public var encoderScaleFactor: Float32 = 0.18215
     /// Scale factor to use on the latent before decoding
     public var decoderScaleFactor: Float32 = 0.18215
-
+    /// If `originalSize` is not the same as `targetSize` the image will appear to be down- or upsampled.
+    /// Part of SDXL’s micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952.
+    public var originalSize: Float32 = 1024
+    /// `cropsCoordsTopLeft` can be used to generate an image that appears to be “cropped” from the position `cropsCoordsTopLeft` downwards.
+    /// Favorable, well-centered images are usually achieved by setting `cropsCoordsTopLeft` to (0, 0).
+    public var cropsCoordsTopLeft: Float32 = 0
+    /// In most cases, `targetSize` should be set to the desired height and width of the generated image.
+    public var targetSize: Float32 = 1024
+    /// Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+    public var aestheticScore: Float32 = 6
+    /// Used to simulate an aesthetic score of the generated image by influencing the negative text condition.
+    public var negativeAestheticScore: Float32 = 2.5
+
     /// Given the configuration, what mode will be used for generation
     public var mode: PipelineMode {
         guard startingImage != nil else {
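
A minimal usage sketch (not from this commit) of the new refiner and SDXL micro-conditioning knobs; the `init(prompt:)` initializer and `stepCount` property are assumed from the surrounding struct:

```swift
// With 50 steps and refinerStart = 0.8, the base unet denoises steps 0..<40
// and the refiner unet (if bundled) takes over for the remaining 10 steps.
var config = PipelineConfiguration(prompt: "a photo of an astronaut riding a horse on mars")
config.stepCount = 50
config.refinerStart = 0.8
config.originalSize = 1024        // SDXL micro-conditioning, section 2.2 of the paper
config.targetSize = 1024
config.aestheticScore = 6
config.negativeAestheticScore = 2.5
```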

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 2 additions & 0 deletions

@@ -26,7 +26,9 @@ public enum StableDiffusionRNG {
 }
 
 public enum PipelineError: String, Swift.Error {
+    case missingUnetInputs
     case startingImageProvidedWithoutEncoder
+    case startingText2ImgWithoutTextEncoder
     case unsupportedOSVersion
 }
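
A hedged sketch of surfacing the new error cases when driving generation; `pipeline` and `config` are assumed from the examples above, and `generateImages(configuration:)` from the existing pipeline API:

```swift
// Hypothetical call site distinguishing the new failure modes.
do {
    let images = try pipeline.generateImages(configuration: config)
    print("Generated \(images.count) image(s)")
} catch PipelineError.startingText2ImgWithoutTextEncoder {
    // TextEncoder.mlmodelc was absent, but a text-to-image run was requested.
} catch PipelineError.missingUnetInputs {
    // The converted unet does not expose the inputs this pipeline expects.
}
```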

swift/StableDiffusion/pipeline/StableDiffusionXL+Resources.swift

Lines changed: 32 additions & 4 deletions

@@ -15,6 +15,9 @@ public extension StableDiffusionXLPipeline {
         public let unetURL: URL
         public let unetChunk1URL: URL
         public let unetChunk2URL: URL
+        public let unetRefinerURL: URL
+        public let unetRefinerChunk1URL: URL
+        public let unetRefinerChunk2URL: URL
         public let decoderURL: URL
         public let encoderURL: URL
         public let vocabURL: URL

@@ -26,6 +29,9 @@ public extension StableDiffusionXLPipeline {
             unetURL = baseURL.appending(path: "Unet.mlmodelc")
             unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc")
             unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc")
+            unetRefinerURL = baseURL.appending(path: "UnetRefiner.mlmodelc")
+            unetRefinerChunk1URL = baseURL.appending(path: "UnetRefinerChunk1.mlmodelc")
+            unetRefinerChunk2URL = baseURL.appending(path: "UnetRefinerChunk2.mlmodelc")
             decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc")
             encoderURL = baseURL.appending(path: "VAEEncoder.mlmodelc")
             vocabURL = baseURL.appending(path: "vocab.json")

@@ -51,7 +57,12 @@ public extension StableDiffusionXLPipeline {
         /// Expect URL of each resource
         let urls = ResourceURLs(resourcesAt: baseURL)
         let tokenizer = try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL)
-        let textEncoder = TextEncoderXL(tokenizer: tokenizer, modelAt: urls.textEncoderURL, configuration: config)
+        let textEncoder: TextEncoderXL?
+        if FileManager.default.fileExists(atPath: urls.textEncoderURL.path) {
+            textEncoder = TextEncoderXL(tokenizer: tokenizer, modelAt: urls.textEncoderURL, configuration: config)
+        } else {
+            textEncoder = nil
+        }
 
         // padToken is different in the second XL text encoder
         let tokenizer2 = try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL, padToken: "!")

@@ -67,13 +78,29 @@ public extension StableDiffusionXLPipeline {
             unet = Unet(modelAt: urls.unetURL, configuration: config)
         }
 
+        // Refiner Unet model
+        let unetRefiner: Unet?
+        if FileManager.default.fileExists(atPath: urls.unetRefinerChunk1URL.path) &&
+            FileManager.default.fileExists(atPath: urls.unetRefinerChunk2URL.path) {
+            unetRefiner = Unet(chunksAt: [urls.unetRefinerChunk1URL, urls.unetRefinerChunk2URL],
+                               configuration: config)
+        } else if FileManager.default.fileExists(atPath: urls.unetRefinerURL.path) {
+            unetRefiner = Unet(modelAt: urls.unetRefinerURL, configuration: config)
+        } else {
+            unetRefiner = nil
+        }
+
         // Image Decoder
-        let decoder = Decoder(modelAt: urls.decoderURL, configuration: config)
-
+        // FIXME: Hardcoding to .cpuAndGPU since ANE doesn't support FLOAT32
+        let vaeConfig = config.copy() as! MLModelConfiguration
+        vaeConfig.computeUnits = .cpuAndGPU
+        let decoder = Decoder(modelAt: urls.decoderURL, configuration: vaeConfig)
+
         // Optional Image Encoder
         let encoder: Encoder?
         if FileManager.default.fileExists(atPath: urls.encoderURL.path) {
-            encoder = Encoder(modelAt: urls.encoderURL, configuration: config)
+            encoder = Encoder(modelAt: urls.encoderURL, configuration: vaeConfig)
         } else {
             encoder = nil
         }

@@ -83,6 +110,7 @@ public extension StableDiffusionXLPipeline {
             textEncoder: textEncoder,
             textEncoder2: textEncoder2,
             unet: unet,
+            unetRefiner: unetRefiner,
             decoder: decoder,
             encoder: encoder,
             reduceMemory: reduceMemory
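
Taken together, a hedged end-to-end sketch of loading a bundled resource directory (where `UnetRefiner.mlmodelc` is picked up if present) and generating with the refiner; the initializer shape and `loadResources()` call are assumed from the package's existing API:

```swift
import Foundation
import CoreML
import StableDiffusion

// <output-dir> is the directory passed to torch2coreml with
// --bundle-resources-for-swift-cli; substitute a real path.
let resourcesURL = URL(fileURLWithPath: "<output-dir>/Resources")

let mlConfig = MLModelConfiguration()
mlConfig.computeUnits = .cpuAndGPU   // recommended with --attention-implementation ORIGINAL

let pipeline = try StableDiffusionXLPipeline(resourcesAt: resourcesURL,
                                             configuration: mlConfig,
                                             reduceMemory: false)
try pipeline.loadResources()

var config = PipelineConfiguration(prompt: "a photo of an astronaut riding a horse on mars")
config.refinerStart = 0.8
let images = try pipeline.generateImages(configuration: config)
```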
