@@ -178,8 +178,8 @@ class GraphElement(ReportElement):
178178 def __init__ (self , env : ReportEnv , rows : int , cols : int , row : Optional [int ] = 0 , col : Optional [int ] = 0 ,
179179 colspan : Optional [int ] = 1 , rowspan : Optional [int ] = 1 , polar : Optional [bool ] = False ):
180180 super ().__init__ (env )
181- self .axes = plt .subplot2grid ((rows , cols ), (row , col ), colspan = colspan , rowspan = rowspan , fig = self . env . figure ,
182- polar = polar )
181+ self .axes = plt .subplot2grid ((rows , cols ), (row , col ), colspan = colspan , rowspan = rowspan ,
182+ fig = self . env . figure , polar = polar )
183183
184184 @abstractmethod
185185 async def render (self , ** kwargs ):
@@ -215,7 +215,8 @@ def _plot(self):
215215 self .env .filename = f'{ uuid .uuid4 ()} .png'
216216 self .env .buffer = BytesIO ()
217217 warnings .filterwarnings ("ignore" , category = UserWarning , message = ".*Glyph.*" )
218- self .env .figure .savefig (self .env .buffer , format = 'png' , bbox_inches = 'tight' )
218+ self .env .figure .savefig (self .env .buffer , format = 'png' , bbox_inches = 'tight' ,
219+ facecolor = self .env .figure .get_facecolor (), transparent = False )
219220 self .env .buffer .seek (0 )
220221
221222 async def _async_plot (self ):
@@ -226,13 +227,16 @@ async def render(self, width: int, height: int, cols: int, rows: int, elements:
226227 facecolor : Optional [str ] = '#2C2F33' ):
227228 plt .style .use ('dark_background' )
228229 plt .rcParams ['axes.facecolor' ] = facecolor
230+ plt .rcParams ['figure.facecolor' ] = facecolor
231+ plt .rcParams ['savefig.facecolor' ] = facecolor
229232 fonts = get_supported_fonts ()
230233 if fonts :
231234 plt .rcParams ['font.family' ] = [f"Noto Sans { x } " for x in fonts ] + ['sans-serif' ]
232235 self .env .figure = plt .figure (figsize = (width , height ))
233236 try :
234237 if facecolor :
235238 self .env .figure .set_facecolor (facecolor )
239+ self .env .figure .patch .set_facecolor (facecolor )
236240 tasks = []
237241 for element in elements :
238242 if 'params' in element :
@@ -243,13 +247,13 @@ async def render(self, width: int, height: int, cols: int, rows: int, elements:
243247 if not element_class and 'type' in element :
244248 element_class = getattr (sys .modules [__name__ ], element ['type' ])
245249 if element_class :
246- # remove parameters, that are not in the class __init__ signature
250+ # remove the parameters that are not in the class __init__ signature
247251 signature = inspect .signature (element_class .__init__ ).parameters .keys ()
248252 class_args = {name : value for name , value in element_args .items () if name in signature }
249253 # instantiate the class
250254 element_class = element_class (self .env , rows , cols , ** class_args )
251255 if isinstance (element_class , (GraphElement , MultiGraphElement )):
252- # remove parameters, that are not in the render methods signature
256+ # remove the parameters that are not in the render methods signature
253257 signature = inspect .signature (element_class .render ).parameters .keys ()
254258 render_args = {name : value for name , value in element_args .items () if name in signature }
255259 tasks .append (asyncio .create_task (element_class .render (** render_args )))
0 commit comments