Skip to content

Commit bae1e87

Browse files
committed
Refactor, fix install/setup?
1 parent 1b3257b commit bae1e87

31 files changed

+201
-164
lines changed

dreambooth/dataclasses/db_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
from pydantic import BaseModel
1010

11-
from dreambooth import shared # noqa
12-
from dreambooth.dataclasses.db_concept import Concept # noqa
13-
from dreambooth.dataclasses.ss_model_spec import build_metadata
14-
from dreambooth.utils.image_utils import get_scheduler_names # noqa
15-
from dreambooth.utils.utils import list_attention, select_precision, select_attention
11+
from extensions.sd_dreambooth_extension.dreambooth import shared # noqa
12+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept # noqa
13+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.ss_model_spec import build_metadata
14+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names # noqa
15+
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import list_attention, select_precision, select_attention
1616

1717
# Keys to save, replacing our dumb __init__ method
1818
save_keys = []

dreambooth/dataclasses/train_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from PIL import Image
22

3-
from dreambooth.dataclasses.db_config import DreamboothConfig
3+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
44

55

66
class TrainResult:

dreambooth/dataset/bucket_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
22
from typing import Tuple
33

4-
from dreambooth.dataset.db_dataset import DbDataset
4+
from extensions.sd_dreambooth_extension.dreambooth.dataset.db_dataset import DbDataset
55

66

77
class BucketSampler:

dreambooth/dataset/class_dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
from torch.utils.data import Dataset
55

6-
from dreambooth import shared
7-
from dreambooth.dataclasses.db_concept import Concept
8-
from dreambooth.dataclasses.prompt_data import PromptData
9-
from dreambooth.shared import status
10-
from dreambooth.utils.image_utils import FilenameTextGetter, \
6+
from extensions.sd_dreambooth_extension.dreambooth import shared
7+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
8+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
9+
from extensions.sd_dreambooth_extension.dreambooth.shared import status
10+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import FilenameTextGetter, \
1111
make_bucket_resolutions, \
1212
sort_prompts, get_images
1313
from helpers.mytqdm import mytqdm

dreambooth/dataset/db_dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from torchvision.transforms import transforms
1111
from transformers import CLIPTokenizer
1212

13-
from dreambooth import shared
14-
from dreambooth.dataclasses.prompt_data import PromptData
15-
from dreambooth.shared import status
16-
from dreambooth.utils.image_utils import make_bucket_resolutions, \
13+
from extensions.sd_dreambooth_extension.dreambooth import shared
14+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
15+
from extensions.sd_dreambooth_extension.dreambooth.shared import status
16+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import make_bucket_resolutions, \
1717
closest_resolution, shuffle_tags, open_and_trim
18-
from dreambooth.utils.text_utils import build_strict_tokens
18+
from extensions.sd_dreambooth_extension.dreambooth.utils.text_utils import build_strict_tokens
1919
from helpers.mytqdm import mytqdm
2020

2121
logger = logging.getLogger(__name__)

dreambooth/dataset/sample_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from PIL import Image
55

6-
from dreambooth.dataclasses.db_config import DreamboothConfig
7-
from dreambooth.dataclasses.prompt_data import PromptData
8-
from dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
6+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
7+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.prompt_data import PromptData
8+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_images, FilenameTextGetter, \
99
closest_resolution, make_bucket_resolutions
1010

1111

dreambooth/diff_to_sd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from diffusers import UNet2DConditionModel
1616
from torch import Tensor, nn
1717

18-
from dreambooth import shared as shared
19-
from dreambooth.dataclasses.db_config import from_file, DreamboothConfig
20-
from dreambooth.shared import status
21-
from dreambooth.utils.model_utils import unload_system_models, \
18+
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
19+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
20+
from extensions.sd_dreambooth_extension.dreambooth.shared import status
21+
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
2222
reload_system_models, \
2323
safe_unpickle_disabled, import_model_class_from_model_name_or_path
24-
from dreambooth.utils.utils import printi
24+
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
2525
from helpers.mytqdm import mytqdm
2626
from lora_diffusion.lora import merge_lora_to_model
2727

dreambooth/diff_to_sdxl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import torch
1212
from safetensors.torch import load_file, save_file
1313

14-
from dreambooth import shared
15-
from dreambooth.dataclasses.db_config import from_file
16-
from dreambooth.shared import status
17-
from dreambooth.utils.model_utils import unload_system_models, reload_system_models
18-
from dreambooth.utils.utils import printi
14+
from extensions.sd_dreambooth_extension.dreambooth import shared
15+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
16+
from extensions.sd_dreambooth_extension.dreambooth.shared import status
17+
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, reload_system_models
18+
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
1919
from helpers import mytqdm
2020

2121
# =================#

dreambooth/memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import torch
2626
import torch.backends.cudnn
2727

28-
from dreambooth import shared
29-
from dreambooth.utils.utils import cleanup
28+
from extensions.sd_dreambooth_extension.dreambooth import shared
29+
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup
3030

3131

3232
def should_reduce_batch_size(exception: Exception) -> bool:

dreambooth/sd_to_diff.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020
import shutil
2121
import traceback
2222
from typing import Union
23-
import torch
23+
2424
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
2525

26-
from dreambooth import shared
27-
from dreambooth.dataclasses.db_config import DreamboothConfig
28-
from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
26+
from extensions.sd_dreambooth_extension.dreambooth import shared
27+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import DreamboothConfig
28+
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import safe_unpickle_disabled, \
29+
unload_system_models, \
2930
reload_system_models
3031

3132

0 commit comments

Comments
 (0)