Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,9 @@ def __init__(self, **kwargs):
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)

self.standardize_images: bool = kwargs.get('standardize_images', False)

self.preserve_resolutions: bool = kwargs.get(
"preserve_resolutions", False
)
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
Expand Down
69 changes: 39 additions & 30 deletions toolkit/dataloader_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,42 +237,52 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
file_item.crop_x = 0
file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
elif not did_process_poi:
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
divisibility=bucket_tolerance
)
if hasattr(self.dataset_config, 'preserve_resolutions') and self.dataset_config.preserve_resolutions:
# Don't resize at all
file_item.scale_to_width = width
file_item.scale_to_height = height
file_item.crop_width = width
file_item.crop_height = height
file_item.crop_x = 0
file_item.crop_y = 0

else:
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
divisibility=bucket_tolerance
)

# Calculate scale factors for width and height
width_scale_factor = bucket_resolution["width"] / width
height_scale_factor = bucket_resolution["height"] / height
# Calculate scale factors for width and height
width_scale_factor = bucket_resolution["width"] / width
height_scale_factor = bucket_resolution["height"] / height

# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
max_scale_factor = max(width_scale_factor, height_scale_factor)
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
max_scale_factor = max(width_scale_factor, height_scale_factor)

# round up
file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
file_item.scale_to_height = int(math.ceil(height * max_scale_factor))
# round up
file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
file_item.scale_to_height = int(math.ceil(height * max_scale_factor))

file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]

new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]
new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]

if self.dataset_config.random_crop:
# random crop
crop_x = random.randint(0, file_item.scale_to_width - new_width)
crop_y = random.randint(0, file_item.scale_to_height - new_height)
file_item.crop_x = crop_x
file_item.crop_y = crop_y
else:
# do central crop
file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
if self.dataset_config.random_crop:
# random crop
crop_x = random.randint(0, file_item.scale_to_width - new_width)
crop_y = random.randint(0, file_item.scale_to_height - new_height)
file_item.crop_x = crop_x
file_item.crop_y = crop_y
else:
# do central crop
file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)

if file_item.crop_y < 0 or file_item.crop_x < 0:
print_acc('debug')
if file_item.crop_y < 0 or file_item.crop_x < 0:
print_acc('debug')

# check if bucket exists, if not, create it
bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
Expand All @@ -289,7 +299,6 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
print_acc(f'{key}: {len(bucket.file_list_idx)} files')
print_acc(f'{len(self.buckets)} buckets made')


class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
Expand Down
2 changes: 1 addition & 1 deletion toolkit/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def print_acc(*args, **kwargs):
class Logger:
def __init__(self, filename):
self.terminal = sys.stdout
self.log = open(filename, 'a')
self.log = open(filename, 'a', encoding='utf-8')

def write(self, message):
self.terminal.write(message)
Expand Down
28 changes: 18 additions & 10 deletions ui/src/app/jobs/new/SimpleJob.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ export default function SimpleJob({
count += 1; // add quantization card
}
return count;

}, [modelArch]);

let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
Expand Down Expand Up @@ -78,7 +77,7 @@ export default function SimpleJob({
let ARAs: SelectOption[] = [];
if (modelArch.accuracyRecoveryAdapters) {
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
ARAs.push({ value, label });
ARAs.push({ value, label });
}
}
if (ARAs.length > 0) {
Expand Down Expand Up @@ -270,14 +269,14 @@ export default function SimpleJob({
/>
</FormGroup>
<NumberInput
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
</Card>
)}
<Card title="Target">
Expand Down Expand Up @@ -638,6 +637,14 @@ export default function SimpleJob({
docKey="datasets.do_i2v"
/>
)}
<Checkbox
label="Preserve Resolutions"
checked={dataset.preserve_resolutions || false}
onChange={value =>
setJobConfig(value, `config.process[0].datasets[${i}].preserve_resolutions`)
}
docKey="datasets.preserve_resolutions"
/>
</FormGroup>
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
<Checkbox
Expand All @@ -663,6 +670,7 @@ export default function SimpleJob({
{resGroup.map(res => (
<Checkbox
key={res}
disabled={dataset.preserve_resolutions}
label={res.toString()}
checked={dataset.resolution.includes(res)}
onChange={value => {
Expand Down
15 changes: 13 additions & 2 deletions ui/src/docs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'datasets.preserve_resolutions': {
title: 'Preserve Resolutions',
description: (
<>
This disables any kind of resizing or bucketing and will train your images at their original resolutions.
<br />
<br />
Any specified resolution settings will be ignored.
</>
),
},
'datasets.do_i2v': {
title: 'Do I2V',
description: (
Expand Down Expand Up @@ -132,8 +143,8 @@ const docs: { [key: string]: ConfigDoc } = {
Some models have multi stage networks that are trained and used separately in the denoising process. Most
common, is to have 2 stages. One for high noise and one for low noise. You can choose to train both stages at
once or train them separately. If trained at the same time, The trainer will alternate between training each
model every so many steps and will output 2 different LoRAs. If you choose to train only one stage, the
trainer will only train that stage and output a single LoRA.
model every so many steps and will output 2 different LoRAs. If you choose to train only one stage, the trainer
will only train that stage and output a single LoRA.
</>
),
},
Expand Down
1 change: 1 addition & 0 deletions ui/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ export interface DatasetConfig {
num_frames: number;
shrink_video_to_frames: boolean;
do_i2v: boolean;
preserve_resolutions: boolean;
flip_x: boolean;
flip_y: boolean;
}
Expand Down