@@ -419,6 +419,194 @@ def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
419419
420420# ==============================================================================
421421
422+ class ElementwiseLtFloatScalarModule (torch .nn .Module ):
423+ def __init__ (self ):
424+ super ().__init__ ()
425+
426+ @export
427+ @annotate_args ([
428+ None ,
429+ ([- 1 , - 1 ], torch .float32 , True ),
430+ ])
431+ def forward (self , x ):
432+ return torch .lt (x , 0.6 )
433+
434+
435+ @register_test_case (module_factory = lambda : ElementwiseLtFloatScalarModule ())
436+ def ElementwiseLtFloatScalarModule_basic (module , tu : TestUtils ):
437+ module .forward (tu .rand (3 , 5 ))
438+
439+
440+ class ElementwiseLtIntScalarModule (torch .nn .Module ):
441+ def __init__ (self ):
442+ super ().__init__ ()
443+
444+ @export
445+ @annotate_args ([
446+ None ,
447+ ([- 1 , - 1 ], torch .int64 , True ),
448+ ])
449+ def forward (self , x ):
450+ return torch .lt (x , 0 )
451+
452+
453+ @register_test_case (module_factory = lambda : ElementwiseLtIntScalarModule ())
454+ def ElementwiseLtIntScalarModule_basic (module , tu : TestUtils ):
455+ module .forward (torch .randint (- 10 , 15 , (3 ,4 )))
456+
457+
458+ class ElementwiseLtDiffWidthScalarModule (torch .nn .Module ):
459+ def __init__ (self ):
460+ super ().__init__ ()
461+
462+ @export
463+ @annotate_args ([
464+ None ,
465+ ([- 1 , - 1 ], torch .int32 , True ),
466+ ])
467+ def forward (self , x ):
468+ return torch .lt (x , 2 )
469+
470+
471+ @register_test_case (module_factory = lambda : ElementwiseLtDiffWidthScalarModule ())
472+ def ElementwiseLtDiffWidthScalarModule_basic (module , tu : TestUtils ):
473+ module .forward (torch .randint (- 10 , 15 , (3 ,4 )).to (torch .int32 ))
474+
475+
476+ class ElementwiseLtFloatTensorModule (torch .nn .Module ):
477+ def __init__ (self ):
478+ super ().__init__ ()
479+
480+ @export
481+ @annotate_args ([
482+ None ,
483+ ([- 1 , - 1 ], torch .float32 , True ),
484+ ([- 1 ], torch .float32 , True ),
485+ ])
486+ def forward (self , x , y ):
487+ return torch .lt (x , y )
488+
489+
490+ @register_test_case (module_factory = lambda : ElementwiseLtFloatTensorModule ())
491+ def ElementwiseLtFloatTensorModule_basic (module , tu : TestUtils ):
492+ module .forward (tu .rand (3 , 5 ), tu .rand (5 ))
493+
494+
495+ class ElementwiseLtIntTensorModule (torch .nn .Module ):
496+ def __init__ (self ):
497+ super ().__init__ ()
498+
499+ @export
500+ @annotate_args ([
501+ None ,
502+ ([- 1 , - 1 ], torch .int64 , True ),
503+ ([- 1 ], torch .int64 , True ),
504+ ])
505+ def forward (self , x , y ):
506+ return torch .lt (x , y )
507+
508+
509+ @register_test_case (module_factory = lambda : ElementwiseLtIntTensorModule ())
510+ def ElementwiseLtIntTensorModule_basic (module , tu : TestUtils ):
511+ module .forward (torch .randint (10 , (3 , 5 )), torch .randint (10 , (5 ,)))
512+
513+ # ==============================================================================
514+
515+ class ElementwiseEqFloatScalarModule (torch .nn .Module ):
516+ def __init__ (self ):
517+ super ().__init__ ()
518+
519+ @export
520+ @annotate_args ([
521+ None ,
522+ ([- 1 , - 1 ], torch .float32 , True ),
523+ ])
524+ def forward (self , x ):
525+ return torch .eq (x , 6.0 )
526+
527+
528+ @register_test_case (module_factory = lambda : ElementwiseEqFloatScalarModule ())
529+ def ElementwiseEqFloatScalarModule_basic (module , tu : TestUtils ):
530+ module .forward (torch .tensor ([[1.0 , 2.2 , 6.0 ], [6.0 , 2.0 , 3.1 ]])
531+ .to (torch .float32 ))
532+
533+
534+ class ElementwiseEqIntScalarModule (torch .nn .Module ):
535+ def __init__ (self ):
536+ super ().__init__ ()
537+
538+ @export
539+ @annotate_args ([
540+ None ,
541+ ([- 1 , - 1 ], torch .int64 , True ),
542+ ])
543+ def forward (self , x ):
544+ return torch .eq (x , 2 )
545+
546+
547+ @register_test_case (module_factory = lambda : ElementwiseEqIntScalarModule ())
548+ def ElementwiseEqIntScalarModule_basic (module , tu : TestUtils ):
549+ module .forward (torch .randint (2 , 4 , (5 ,8 )))
550+
551+
552+ class ElementwiseEqDiffWidthScalarModule (torch .nn .Module ):
553+ def __init__ (self ):
554+ super ().__init__ ()
555+
556+ @export
557+ @annotate_args ([
558+ None ,
559+ ([- 1 , - 1 ], torch .int32 , True ),
560+ ])
561+ def forward (self , x ):
562+ return torch .eq (x , 2 )
563+
564+
565+ @register_test_case (module_factory = lambda : ElementwiseEqDiffWidthScalarModule ())
566+ def ElementwiseEqDiffWidthScalarModule_basic (module , tu : TestUtils ):
567+ module .forward (torch .randint (2 , 4 , (5 ,8 )).to (torch .int32 ))
568+
569+
570+ class ElementwiseEqFloatTensorModule (torch .nn .Module ):
571+ def __init__ (self ):
572+ super ().__init__ ()
573+
574+ @export
575+ @annotate_args ([
576+ None ,
577+ ([- 1 , - 1 ], torch .float32 , True ),
578+ ([- 1 ], torch .float32 , True ),
579+ ])
580+ def forward (self , x , y ):
581+ return torch .eq (x , y )
582+
583+
584+ @register_test_case (module_factory = lambda : ElementwiseEqFloatTensorModule ())
585+ def ElementwiseEqFloatTensorModule_basic (module , tu : TestUtils ):
586+ module .forward (torch .tensor ([[1.0 , 2.2 , 6.0 ], [6.0 , 2.0 , 3.1 ]])
587+ .to (torch .float32 ),
588+ torch .tensor ([1.0 , 2.4 , 6.0 ]).to (torch .float32 ))
589+
590+
591+ class ElementwiseEqIntTensorModule (torch .nn .Module ):
592+ def __init__ (self ):
593+ super ().__init__ ()
594+
595+ @export
596+ @annotate_args ([
597+ None ,
598+ ([- 1 , - 1 ], torch .int64 , True ),
599+ ([- 1 ], torch .int64 , True ),
600+ ])
601+ def forward (self , x , y ):
602+ return torch .eq (x , y )
603+
604+
605+ @register_test_case (module_factory = lambda : ElementwiseEqIntTensorModule ())
606+ def ElementwiseEqIntTensorModule_basic (module , tu : TestUtils ):
607+ module .forward (torch .randint (2 , 4 , (8 , 5 )), torch .randint (2 , 4 , (5 ,)))
608+
609+ # ==============================================================================
422610
423611class ElementwiseClampModule (torch .nn .Module ):
424612 def __init__ (self ):
0 commit comments