@@ -325,13 +325,132 @@ def diff_2_4(a, i_1, i_2, j, dic, matrices, attn_1):
325
325
return t_4
326
326
327
327
328
+ def diff_2_3_4 (a , i_1 , i_2 , j , dic , matrices , attn_1 ):
329
+
330
+ (term_0 , term_1 , term_2 , term_3 , term_4 , term_5 , term_6 , term_7 , term_8 ) = matrices
331
+
332
+ if j == i_1 :
333
+ return 0
334
+ for k in range (i_2 + 1 ):
335
+ if j != 0 and j != 1 :
336
+ c = (
337
+ term_4 [k , dic [k ], j - 1 ][..., dic [j - 1 ]].max ()
338
+ + term_3 [i_2 , a , j - 1 , dic [j - 1 ]].max ()
339
+ )
340
+ # new = c.clone()
341
+ d = c * attn_1 [dic [j ], j - 1 ].min ()
342
+
343
+ for i in range (0 , j - 1 ):
344
+
345
+ c = torch .max (
346
+ c ,
347
+ term_4 [k , dic [k ], i ][..., dic [i ]].max ()
348
+ + term_3 [i_2 , dic [i_2 ], i , dic [i ]].max (),
349
+ )
350
+ c = torch .max (
351
+ c ,
352
+ term_4 [k , dic [k ], j ][..., dic [j ]].max ()
353
+ + term_3 [i_2 , dic [i_2 ], j , dic [j ]].max (),
354
+ )
355
+ d = d + (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
356
+
357
+ if j == 0 :
358
+
359
+ d = (
360
+ term_4 [k , dic [k ], j ][..., dic [j ]].max ()
361
+ + term_3 [i_2 , a , j , dic [j ]].max ()
362
+ )
363
+
364
+ if j == 1 :
365
+ c = (
366
+ term_4 [k , dic [k ], j - 1 ][..., dic [j - 1 ]].max ()
367
+ + term_3 [i_2 , a , j - 1 , dic [j - 1 ]].max ()
368
+ )
369
+ # new=c.clone()
370
+ d = c * attn_1 [dic [j ], j - 1 ].min ()
371
+ c = torch .max (
372
+ c ,
373
+ term_4 [k , dic [k ], j ][..., dic [j ]].max ()
374
+ + term_3 [i_2 , a , j , dic [j ]].max (),
375
+ )
376
+ d = d + (1 - attn_1 [dic [j ], j - 1 ].min ()) * c
377
+
378
+ # print(d)
379
+ if i_1 != 1 :
380
+ c = term_4 [k , dic [k ], i_1 - 1 , a ].min () + term_3 [i_2 , a , i_1 - 1 , a ]
381
+ # new=c.clone()
382
+ d = d - attn_1 [dic [i_1 ], i_1 - 1 ].min () * c
383
+
384
+ for i in range (0 , i_1 - 1 ):
385
+
386
+ c = torch .min (
387
+ c ,
388
+ term_4 [k , dic [k ], i ][..., dic [i ]].min ()
389
+ + term_3 [i_2 , dic [i_2 ], i , dic [i ]].min (),
390
+ )
391
+ c = torch .min (
392
+ c ,
393
+ term_4 [k , dic [k ], i_1 ][..., dic [i_1 ]].min ()
394
+ + term_3 [i_2 , dic [i_2 ], i_1 , dic [i_1 ]].min (),
395
+ )
396
+ d = d - (1 - attn_1 [dic [i_1 ], i_1 - 1 ].min ()) * c
397
+
398
+ if i_1 == 1 :
399
+ c = term_4 [k , dic [k ], i_1 - 1 , a ].min () + term_3 [i_2 , a , i_1 - 1 , a ]
400
+ # new=c.clone()
401
+ d = d - attn_1 [dic [i_1 ], i_1 - 1 ].min () * c
402
+
403
+ c = torch .min (
404
+ c ,
405
+ term_4 [k , dic [k ], i_1 ][..., dic [i_1 ]].min ()
406
+ + term_3 [i_2 , a , i_1 , dic [i_1 ]].min (),
407
+ )
408
+ d = d - (1 - attn_1 [dic [i_1 ], i_1 - 1 ].min ()) * c
409
+
410
+ # print(d)
411
+
412
+ if type (dic [j ]) == int :
413
+ d = (
414
+ d
415
+ + (
416
+ term_2 [k , dic [k ], j ][..., dic [j ]]
417
+ - term_2 [k , dic [k ], i_1 ][..., dic [i_1 ]].min (dim = - 1 ).values
418
+ ).max ()
419
+ )
420
+
421
+ else :
422
+ d = (
423
+ d
424
+ + (
425
+ term_2 [k , dic [k ], j ][..., dic [j ]].max (dim = - 1 ).values
426
+ - term_2 [k , dic [k ], i_1 ][..., dic [i_1 ]].min (dim = - 1 ).values
427
+ ).max ()
428
+ )
429
+
430
+ if k == 0 :
431
+
432
+ f = d
433
+
434
+ if k != 0 :
435
+ f = torch .max (f , d )
436
+
437
+ if k == i_2 - 1 :
438
+
439
+ g = d .clone ()
440
+
441
+ t_4 = g * attn_1 [dic [i_2 ], i_2 - 1 ]
442
+ t_4 = t_4 + (1 - attn_1 [dic [i_2 ], i_2 - 1 ]) * f
443
+
444
+ return t_4
445
+
446
+
328
447
def least_attention (a , i_1 , i_2 , j , dic , matrices , attn_1 ):
329
448
e = diff_2_4 (a , i_1 , i_2 , j , dic , matrices , attn_1 )
330
449
331
450
return (
332
451
diff_1 (a , i_1 , i_2 , j , dic , matrices )
333
- + diff_3 (a , i_1 , i_2 , j , dic , matrices , attn_1 )
334
452
+ e
453
+ + diff_3 (a , i_1 , i_2 , j , dic , matrices , attn_1 )
335
454
)
336
455
337
456
@@ -587,12 +706,124 @@ def loss_diff_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n=None):
587
706
return ld_4
588
707
589
708
709
+ def loss_diff_3_4 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n = None ):
710
+
711
+ (term_0 , term_1 , term_2 , term_3 , term_4 , term_5 , term_6 , term_7 , term_8 ) = matrices
712
+
713
+ if n == b :
714
+ return 0
715
+
716
+ if n is None :
717
+
718
+ n = torch .arange (d_voc )[torch .arange (d_voc ) != b ]
719
+
720
+ for k in range (i_2 + 1 ):
721
+ if k != 0 and k != 1 :
722
+ c = (
723
+ term_8 [k - 1 , dic [k - 1 ]][..., n ]
724
+ - term_8 [k - 1 , dic [k - 1 ], b ].unsqueeze (dim = - 1 )
725
+ ).max ()
726
+ d = c * attn_1 [dic [k ], k - 1 ].min ()
727
+ for i in range (k - 1 ):
728
+ c = torch .max (
729
+ c ,
730
+ (
731
+ term_8 [i , dic [i ]][..., n ]
732
+ - term_8 [i , dic [i ], b ].unsqueeze (dim = - 1 )
733
+ ).max (),
734
+ )
735
+ c = torch .max (
736
+ c ,
737
+ (
738
+ term_8 [k , dic [k ]][..., n ]
739
+ - term_8 [k , dic [k ], b ].unsqueeze (dim = - 1 )
740
+ ).max (),
741
+ )
742
+ d += (1 - attn_1 [dic [k ], k - 1 ].min ()) * c
743
+
744
+ if k == 0 :
745
+ d = (
746
+ term_8 [0 , dic [0 ]][..., n ] - term_8 [0 , dic [0 ], b ].unsqueeze (dim = - 1 )
747
+ ).max ()
748
+
749
+ if k == 1 :
750
+ c = (
751
+ term_8 [0 , dic [0 ]][..., n ] - term_8 [0 , dic [0 ], b ].unsqueeze (dim = - 1 )
752
+ ).max ()
753
+ d = c * attn_1 [dic [k ], k - 1 ].min ()
754
+ c = torch .max (
755
+ c ,
756
+ (
757
+ term_8 [1 , dic [1 ]][..., n ]
758
+ - term_8 [1 , dic [1 ], b ].unsqueeze (dim = - 1 )
759
+ ).max (),
760
+ )
761
+ d += (1 - attn_1 [dic [k ], k - 1 ].min ()) * c
762
+
763
+ d = (
764
+ d
765
+ + (
766
+ term_7 [k , dic [k ]][..., n ] - term_7 [k , dic [k ], b ].unsqueeze (dim = - 1 )
767
+ ).max ()
768
+ )
769
+
770
+ if k == 0 :
771
+ f = d
772
+ if k != 0 :
773
+ f = torch .max (f , d )
774
+ if k == i_1 :
775
+ g = d
776
+ ld_4 = g * (bound_2 [dic [i_2 ], i_2 , i_1 ].min ())
777
+ ld_4 += (1 - bound_2 [dic [i_2 ], i_2 , i_1 ].min ()) * f
778
+ return ld_4
779
+
780
+ for k in range (i_2 + 1 ):
781
+ if k != 0 and k != 1 :
782
+ c = (term_8 [k - 1 , dic [k - 1 ], n ] - term_8 [k - 1 , dic [k - 1 ], b ]).max ()
783
+ d = c * attn_1 [dic [k ], k - 1 ].min ()
784
+ for i in range (k - 1 ):
785
+ c = torch .max (
786
+ c ,
787
+ (term_8 [i , dic [i ], n ] - term_8 [i , dic [i ], b ]).max (),
788
+ )
789
+ c = torch .max (
790
+ c ,
791
+ (term_8 [k , dic [k ], n ] - term_8 [k , dic [k ], b ]).max (),
792
+ )
793
+ d += (1 - attn_1 [dic [k ], k - 1 ].min ()) * c
794
+
795
+ if k == 0 :
796
+ d = (term_8 [0 , dic [0 ], n ] - term_8 [0 , dic [0 ], b ]).max ()
797
+
798
+ if k == 1 :
799
+ c = (term_8 [0 , dic [0 ], n ] - term_8 [0 , dic [0 ], b ]).max ()
800
+ d = c * attn_1 [dic [k ], k - 1 ].min ()
801
+ c = torch .max (
802
+ c ,
803
+ (term_8 [1 , dic [1 ], n ] - term_8 [1 , dic [1 ], b ]).max (),
804
+ )
805
+ d += (1 - attn_1 [dic [k ], k - 1 ].min ()) * c
806
+
807
+ d = d + (term_7 [k , dic [k ], n ] - term_7 [k , dic [k ], b ]).max ()
808
+
809
+ if k == 0 :
810
+ f = d
811
+ if k != 0 :
812
+ f = torch .max (f , d )
813
+ if k == i_1 :
814
+ g = d
815
+ ld_4 = g * (bound_2 [dic [i_2 ], i_2 , i_1 ].min ())
816
+ ld_4 += (1 - bound_2 [dic [i_2 ], i_2 , i_1 ].min ()) * f
817
+ return ld_4
818
+
819
+
590
820
def total_bound (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n = None ):
591
821
return (
592
822
loss_diff_1 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n )
593
823
+ loss_diff_2 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n )
594
- + loss_diff_3 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n )
595
- + loss_diff_4 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n )
824
+ + loss_diff_3_4 (b , i_1 , i_2 , dic , matrices , attn_1 , bound_2 , n )
825
+ # + loss_diff_3(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
826
+ # + loss_diff_4(b, i_1, i_2, dic, matrices, attn_1, bound_2, n)
596
827
)
597
828
598
829
0 commit comments