@@ -834,16 +834,43 @@ namespace Flux {
834834 FluxRunner (ggml_backend_t backend,
835835 std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836836 const std::string prefix = " " ,
837- SDVersion version = VERSION_FLUX_DEV,
838- bool flash_attn = false )
837+ bool flash_attn = false )
839838 : GGMLRunner(backend) {
840- flux_params.flash_attn = flash_attn;
841- if (version == VERSION_FLUX_SCHNELL) {
842- flux_params.guidance_embed = false ;
839+ flux_params.flash_attn = flash_attn;
840+ flux_params.guidance_embed = false ;
841+ flux_params.depth = 0 ;
842+ flux_params.depth_single_blocks = 0 ;
843+ for (auto pair : tensor_types) {
844+ std::string tensor_name = pair.first ;
845+ if (tensor_name.find (" model.diffusion_model." ) == std::string::npos)
846+ continue ;
847+ if (tensor_name.find (" guidance_in.in_layer.weight" ) != std::string::npos) {
848+ // not schnell
849+ flux_params.guidance_embed = true ;
850+ }
851+ size_t db = tensor_name.find (" double_blocks." );
852+ if (db != std::string::npos) {
853+ tensor_name = tensor_name.substr (db); // remove prefix
854+ int block_depth = atoi (tensor_name.substr (14 , tensor_name.find (" ." , 14 )).c_str ());
855+ if (block_depth + 1 > flux_params.depth ) {
856+ flux_params.depth = block_depth + 1 ;
857+ }
858+ }
859+ size_t sb = tensor_name.find (" single_blocks." );
860+ if (sb != std::string::npos) {
861+ tensor_name = tensor_name.substr (sb); // remove prefix
862+ int block_depth = atoi (tensor_name.substr (14 , tensor_name.find (" ." , 14 )).c_str ());
863+ if (block_depth + 1 > flux_params.depth_single_blocks ) {
864+ flux_params.depth_single_blocks = block_depth + 1 ;
865+ }
866+ }
843867 }
844- if (version == VERSION_FLUX_LITE) {
845- flux_params.depth = 8 ;
868+
869+ LOG_INFO (" Flux blocks: %d double, %d single" , flux_params.depth , flux_params.depth_single_blocks );
870+ if (!flux_params.guidance_embed ) {
871+ LOG_INFO (" Flux guidance is disabled (Schnell mode)" );
846872 }
873+
847874 flux = Flux (flux_params);
848875 flux.init (params_ctx, tensor_types, prefix);
849876 }
0 commit comments