@@ -1106,6 +1106,326 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class PAGJointAttnProcessor2_0:
+    """Attention processor for SD3-like self-attention projections, applying perturbed attention
+    guidance (PAG) by running an original and a perturbed attention path in parallel."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # store the length of the image patch sequence to create a mask that prevents interaction between patches
+        # similar to making the self-attention map an identity matrix
+        identity_block_size = hidden_states.shape[1]
+
+        # chunk
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+        encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
+class PAGCFGJointAttnProcessor2_0:
+    """Attention processor for SD3-like self-attention projections, applying perturbed attention
+    guidance (PAG) together with classifier-free guidance (CFG): the batch carries unconditional,
+    conditional, and perturbed samples."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        identity_block_size = hidden_states.shape[
+            1
+        ]  # patch embeddings width * height (corresponds to self-attention map width or height)
+
+        # chunk
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        (
+            encoder_hidden_states_uncond,
+            encoder_hidden_states_org,
+            encoder_hidden_states_ptb,
+        ) = encoder_hidden_states.chunk(3)
+        encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
+
+        ################## original path ##################
+        batch_size = encoder_hidden_states_org.shape[0]
+
+        # `sample` projections.
+        query_org = attn.to_q(hidden_states_org)
+        key_org = attn.to_k(hidden_states_org)
+        value_org = attn.to_v(hidden_states_org)
+
+        # `context` projections.
+        encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+        encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+        # attention
+        query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+        key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+        value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+        inner_dim = key_org.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states_org = F.scaled_dot_product_attention(
+            query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+        # Split the attention outputs.
+        hidden_states_org, encoder_hidden_states_org = (
+            hidden_states_org[:, : residual.shape[1]],
+            hidden_states_org[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+        if not attn.context_pre_only:
+            encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################## perturbed path ##################
+
+        batch_size = encoder_hidden_states_ptb.shape[0]
+
+        # `sample` projections.
+        query_ptb = attn.to_q(hidden_states_ptb)
+        key_ptb = attn.to_k(hidden_states_ptb)
+        value_ptb = attn.to_v(hidden_states_ptb)
+
+        # `context` projections.
+        encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+        encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+        # attention
+        query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+        key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+        value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+        inner_dim = key_ptb.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # create a full mask with all entries set to 0
+        seq_len = query_ptb.size(2)
+        full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+        # set the attention value between image patches to -inf
+        full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+        # set the diagonal of the attention value between image patches to 0
+        full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+        # expand the mask to match the attention weights shape
+        full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
+
+        hidden_states_ptb = F.scaled_dot_product_attention(
+            query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+        # split the attention outputs.
+        hidden_states_ptb, encoder_hidden_states_ptb = (
+            hidden_states_ptb[:, : residual.shape[1]],
+            hidden_states_ptb[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+        if not attn.context_pre_only:
+            encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        ################ concat ###############
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+        encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+        return hidden_states, encoder_hidden_states
+
+
 class FusedJointAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
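
For reference, the perturbed ("ptb") path above disables image-to-image self-attention by adding a mask to the attention logits: every image-to-image entry is set to -inf except the diagonal, so each image patch attends only to itself, while attention to and from the text tokens is left untouched. Below is a minimal standalone sketch of that mask construction; the token counts are made up for illustration and are not taken from any model.

import torch

# Illustrative sizes (assumed for this sketch only).
num_image_tokens = 4  # plays the role of "identity_block_size" in the processors above
num_text_tokens = 2
seq_len = num_image_tokens + num_text_tokens

# Additive attention mask, initialized to all zeros (no effect on the logits).
full_mask = torch.zeros((seq_len, seq_len))

# Block image->image attention entirely ...
full_mask[:num_image_tokens, :num_image_tokens] = float("-inf")
# ... then re-allow each image token to attend to itself (identity self-attention).
full_mask[:num_image_tokens, :num_image_tokens].fill_diagonal_(0)

# Broadcast over the batch and head dimensions, matching the shape expected by
# F.scaled_dot_product_attention(..., attn_mask=full_mask).
full_mask = full_mask.unsqueeze(0).unsqueeze(0)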